def generate_zca_data():
    train_file = join(DATA_DIR, "train.npz")
    with np.load(train_file, "r") as f:
        train_x = uint8_to_binary_float(f['x'])
        train_y = f['y']
        y_names = f['y_names']

    test_file = join(DATA_DIR, "test.npz")
    with np.load(test_file, "r") as f:
        test_x = uint8_to_binary_float(f['x'])
        test_y = f['y']

    zca = ZCA(x=None, eps=1e-5)

    print("Fit train data!")
    zca.fit(train_x, eps=1e-5)

    print("Transform train data!")
    train_x = zca.transform(train_x)
    print("Transform test data!")
    test_x = zca.transform(test_x)

    assert train_x.shape == (50000, 32, 32, 3), \
        "train_x.shape: {}".format(train_x.shape)
    assert test_x.shape == (10000, 32, 32, 3), \
        "test_x.shape: {}".format(test_x.shape)

    zca_file = join(ZCA_DATA_DIR, "zca.npz")
    train_zca_file = join(ZCA_DATA_DIR, "train.npz")
    test_zca_file = join(ZCA_DATA_DIR, "test.npz")

    print("Save data!")
    zca.save_npz(zca_file)
    np.savez_compressed(train_zca_file, x=train_x, y=train_y, y_names=y_names)
    np.savez_compressed(test_zca_file, x=test_x, y=test_y, y_names=y_names)
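
# NOTE: The ZCA class used above comes from elsewhere in this project. Purely as
# an illustration of what a ZCA whitener computes, a minimal (hypothetical)
# version could look like the sketch below: fit() estimates the mean and the
# whitening matrix U diag(1/sqrt(s + eps)) U^T from flattened training images,
# and transform() applies them. The real class additionally supports save_npz()
# and may differ in details (e.g. eps handling, dtype).
class _ZCASketch(object):
    def fit(self, x, eps=1e-5):
        # Flatten each image to a vector and estimate mean + covariance
        flat = x.reshape(x.shape[0], -1).astype(np.float64)
        self.mean = flat.mean(axis=0)
        cov = np.cov(flat - self.mean, rowvar=False)
        # Eigendecomposition via SVD of the (symmetric) covariance matrix
        u, s, _ = np.linalg.svd(cov)
        self.whitening = u.dot(np.diag(1.0 / np.sqrt(s + eps))).dot(u.T)

    def transform(self, x):
        # Center, whiten, and restore the original image shape
        flat = x.reshape(x.shape[0], -1).astype(np.float64) - self.mean
        return flat.dot(self.whitening).reshape(x.shape).astype(np.float32)
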
def main(args):
    # Create output directory
    # ===================================== #
    args.output_dir = os.path.join(args.output_dir, args.model_name, args.run)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    else:
        if args.force_rm_dir:
            import shutil
            shutil.rmtree(args.output_dir, ignore_errors=True)
            print("Removed '{}'".format(args.output_dir))
        else:
            raise ValueError("Output directory '{}' already exists. "
                             "'force_rm_dir' must be set to True!".format(args.output_dir))
        os.mkdir(args.output_dir)

    save_args(os.path.join(args.output_dir, 'config.json'), args)
    # pp = pprint.PrettyPrinter(indent=4)
    # pp.pprint(args.__dict__)
    # ===================================== #

    # Specify data
    # ===================================== #
    if args.dataset == "mnist":
        x_shape = [28, 28, 1]
    elif args.dataset == "mnist_3" or args.dataset == "mnistm":
        x_shape = [28, 28, 3]
    elif args.dataset == "svhn" or args.dataset == "cifar10" or args.dataset == "cifar100":
        x_shape = [32, 32, 3]
    else:
        raise ValueError("Unsupported dataset '{}'!".format(args.dataset))

    if args.dataset == "cifar100":
        num_classes = 100
    else:
        num_classes = 10

    print("x_shape: {}".format(x_shape))
    print("num_classes: {}".format(num_classes))
    # ===================================== #

    # Load data
    # ===================================== #
    print("Loading {}!".format(args.dataset))

    train_loader = SimpleDataset4SSL()
    train_loader.load_npz_data(args.train_file)
    train_loader.create_ssl_data(args.num_labeled, num_classes=num_classes,
                                 shuffle=True, seed=args.seed)
    if args.input_norm != "applied":
        train_loader.x = uint8_to_binary_float(train_loader.x)
    else:
        print("Input normalization has already been applied on train data!")

    test_loader = SimpleDataset()
    test_loader.load_npz_data(args.test_file)
    if args.input_norm != "applied":
        test_loader.x = uint8_to_binary_float(test_loader.x)
    else:
        print("Input normalization has already been applied on test data!")

    print("train_l/train_u/test: {}/{}/{}".format(
        train_loader.num_labeled_data, train_loader.num_unlabeled_data,
        test_loader.num_data))

    # import matplotlib.pyplot as plt
    # print("train_l.y[:10]: {}".format(train_l.y[:10]))
    # print("train_u.y[:10]: {}".format(train_u.y[:10]))
    # print("test.y[:10]: {}".format(test.y[:10]))
    # fig, axes = plt.subplots(3, 5)
    # for i in range(5):
    #     axes[0][i].imshow(train_l.x[i])
    #     axes[1][i].imshow(train_u.x[i])
    #     axes[2][i].imshow(test.x[i])
    # plt.show()

    if args.dataset == "mnist":
        train_loader.x = np.expand_dims(train_loader.x, axis=-1)
        test_loader.x = np.expand_dims(test_loader.x, axis=-1)
    elif args.dataset == "mnist_3":
        train_loader.x = np.stack(
            [train_loader.x, train_loader.x, train_loader.x], axis=-1)
        test_loader.x = np.stack(
            [test_loader.x, test_loader.x, test_loader.x], axis=-1)

    # Data Preprocessing + Augmentation
    # ------------------------------------- #
    if args.input_norm == 'none' or args.input_norm == 'applied':
        print("Do not apply any normalization!")
    elif args.input_norm == 'zca':
        print("Apply ZCA whitening on data!")
        normalizer = ZCA()
        normalizer.fit(train_loader.x, eps=1e-5)
        train_loader.x = normalizer.transform(train_loader.x)
        test_loader.x = normalizer.transform(test_loader.x)
    elif args.input_norm == 'standard':
        print("Apply Standardization on data!")
        normalizer = Standardization()
        normalizer.fit(train_loader.x)
        train_loader.x = normalizer.transform(train_loader.x)
        test_loader.x = normalizer.transform(test_loader.x)
    else:
        raise ValueError("Unsupported 'input_norm'={}!".format(args.input_norm))
    # ------------------------------------- #
    # ===================================== #
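    # NOTE: 'uint8_to_binary_float' is assumed to map uint8 pixels in [0, 255]
    # to floats in [0, 1] (roughly x.astype(np.float32) / 255.0). Whatever
    # normalizer is chosen above ('zca' or 'standard') is fitted on the
    # training split only and then applied to both train and test data, so no
    # test statistics leak into preprocessing.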

    # Hyperparameters
    # ===================================== #
    hyper_updater = HyperParamUpdater(
        ['lr', 'ema_momentum', 'cent_u_coeff', 'cons_coeff'],
        [args.lr_max, args.ema_momentum_init,
         args.cent_u_coeff_max, args.cons_coeff_max],
        scope='moving_hyperparams')
    # ===================================== #

    # Build model
    # ===================================== #
    # IMPORTANT: Remember to test with no Gaussian noise
    print("args.gauss_noise: {}".format(args.gauss_noise))

    if args.model_name == "9310gaurav":
        main_classifier = MainClassifier_9310gaurav(
            num_classes=num_classes, use_gauss_noise=args.gauss_noise)
    else:
        raise ValueError("Unsupported model_name='{}'!".format(args.model_name))

    # Input Perturber
    # ------------------------------------- #
    # The input perturber only performs horizontal flipping (if enabled) and
    # 'translating_pixels' here (for both CIFAR-10 and SVHN)
    input_perturber = InputPerturber(
        normalizer=None,  # We do not use a normalizer here!
        flip_horizontally=args.flip_horizontally,
        flip_vertically=False,  # We do not flip images vertically!
        translating_pixels=args.translating_pixels,
        noise_std=0.0)  # We do not add noise here!
    # ------------------------------------- #

    # Main model
    # ------------------------------------- #
    model = MeanTeacher(x_shape=x_shape, y_shape=num_classes,
                        main_classifier=main_classifier,
                        input_perturber=input_perturber,
                        cons_mode=args.cons_mode,
                        ema_momentum=hyper_updater.variables['ema_momentum'],
                        cons_4_unlabeled_only=args.cons_4_unlabeled_only,
                        weight_decay=args.weight_decay)

    loss_coeff_dict = {
        'cross_ent_l': args.cross_ent_l,
        'cond_ent_u': hyper_updater.variables['cent_u_coeff'],
        'cons': hyper_updater.variables['cons_coeff'],
    }
    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list(trainable_only=False)
    # ------------------------------------- #
    # ===================================== #

    # Build optimizer
    # ===================================== #
    losses = model.get_loss()
    train_params = model.get_train_params()

    opt_AE = tf.train.MomentumOptimizer(
        learning_rate=hyper_updater.variables['lr'],
        momentum=args.lr_momentum, use_nesterov=True)

    # Contains both the batch norm updates and the teacher (EMA) param updates
    update_ops = model.get_all_update_ops()
    print("update_ops: {}".format(update_ops))
    with tf.control_dependencies(update_ops):
        train_op_AE = opt_AE.minimize(loss=losses['loss'],
                                      var_list=train_params['loss'])
    # ===================================== #

    # Create directories
    # ===================================== #
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img"))

    log_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "log"))
    train_log_file = os.path.join(log_dir, "train.log")

    summary_dir = make_dir_if_not_exist(
        os.path.join(args.output_dir, "summary_tf"))
    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "model_tf"))
    # ===================================== #

    # Create session
    # ===================================== #
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    train_helper = SimpleTrainHelper(log_dir=summary_dir, save_dir=model_dir,
                                     max_to_keep=args.num_save,
                                     max_to_keep_best=args.num_save_best)
    train_helper.initialize(sess, init_variables=True,
                            create_summary_writer=True)
    # ===================================== #

    # Start training
    # ===================================== #
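    # NOTE (Mean Teacher recap): the teacher classifier is an exponential
    # moving average (EMA) of the student's weights, roughly
    #     teacher_param <- m * teacher_param + (1 - m) * student_param
    # with m = ema_momentum after every training step, and the consistency
    # loss penalizes disagreement between student and teacher predictions on
    # perturbed inputs. Here the EMA update is assumed to be included in
    # 'model.get_all_update_ops()' above (see the comment there), so it runs
    # as a control dependency of 'train_op_AE'.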

    # Summarizer
    # ------------------------------------- #
    fetch_keys_AE_l = ['acc_y_l', 'cross_ent_l']
    fetch_keys_AE_u = ['acc_y_u', 'cond_ent_u', 'cons']
    # Compare the MDL loss with xent + consistency to see whether the MDL loss
    # is a better indicator of generalization
    fetch_keys_AE = fetch_keys_AE_l + fetch_keys_AE_u
    train_summarizer = ScalarSummarizer([(key, 'mean') for key in fetch_keys_AE])

    fetch_keys_test = ['acc_y', 'acc_y_stu']
    eval_summarizer = ScalarSummarizer([(key, 'mean') for key in fetch_keys_test])
    # ------------------------------------- #

    # Data sampler
    # ------------------------------------- #
    # If 'batch_size_labeled' <= 0, the number of labeled examples per batch
    # varies during training
    if args.batch_size_labeled <= 0:
        sampler = ContinuousIndexSampler(train_loader.num_data, args.batch_size,
                                         shuffle=True)
        sampling_separately = False
        print("batch_size_l, batch_size_u vary but their sum={}!".format(
            args.batch_size))
    elif 0 < args.batch_size_labeled < args.batch_size:
        batch_size_l = args.batch_size_labeled
        batch_size_u = args.batch_size - args.batch_size_labeled
        print("batch_size_l/batch_size_u: {}/{}".format(
            batch_size_l, batch_size_u))

        # IMPORTANT: Here we must use 'train_loader.labeled_ids' and
        # 'train_loader.unlabeled_ids', NOT 'train_loader.num_labeled_data'
        # and 'train_loader.num_unlabeled_data'
        sampler_l = ContinuousIndexSampler(train_loader.labeled_ids,
                                           batch_size_l, shuffle=True)
        sampler_u = ContinuousIndexSampler(train_loader.unlabeled_ids,
                                           batch_size_u, shuffle=True)
        sampler = ContinuousIndexSamplerGroup(sampler_l, sampler_u)
        sampling_separately = True
    else:
        raise ValueError(
            "'args.batch_size_labeled' ({}) must be <= 0 or smaller than "
            "'args.batch_size' ({})!".format(args.batch_size_labeled,
                                             args.batch_size))
    # ------------------------------------- #

    # Annealer
    # ------------------------------------- #
    step_rampup_annealer = StepAnnealing(args.rampup_len_step, value_0=0, value_1=1)
    sigmoid_rampup_annealer = SigmoidRampup(args.rampup_len_step)
    sigmoid_rampdown_annealer = SigmoidRampdown(args.rampdown_len_step, args.steps)
    # ------------------------------------- #

    # Results Tracker
    # ------------------------------------- #
    tracker = BestResultsTracker([('acc_y', 'greater')],
                                 num_best=args.num_save_best)
    # ------------------------------------- #

    import math
    batches_per_epoch = int(math.ceil(train_loader.num_data / args.batch_size))

    global_step = 0
    log_time_start = time()

    for epoch in range(args.epochs):
        if global_step >= args.steps:
            break

        for batch in range(batches_per_epoch):
            if global_step >= args.steps:
                break
            global_step += 1

            # Update hyperparameters
            # ---------------------------------- #
            step_rampup = step_rampup_annealer.get_value(global_step)
            sigmoid_rampup = sigmoid_rampup_annealer.get_value(global_step)
            sigmoid_rampdown = sigmoid_rampdown_annealer.get_value(global_step)

            lr = sigmoid_rampup * sigmoid_rampdown * args.lr_max
            ema_momentum = (1.0 - step_rampup) * args.ema_momentum_init + \
                step_rampup * args.ema_momentum_final
            cent_u_coeff = sigmoid_rampup * args.cent_u_coeff_max
            cons_coeff = sigmoid_rampup * args.cons_coeff_max

            hyper_updater.update(sess, feed_dict={
                'lr': lr,
                'ema_momentum': ema_momentum,
                'cent_u_coeff': cent_u_coeff,
                'cons_coeff': cons_coeff
            })

            hyper_vals = hyper_updater.get_value(sess)
            hyper_vals['sigmoid_rampup'] = sigmoid_rampup
            hyper_vals['sigmoid_rampdown'] = sigmoid_rampdown
            hyper_vals['step_rampup'] = step_rampup
            # ---------------------------------- #

            # Train model
            # ---------------------------------- #
            if sampling_separately:
                # print("Sample separately!")
                batch_ids_l, batch_ids_u = sampler.sample_group_of_ids()
                xl, yl, label_flag_l = train_loader.fetch_batch(batch_ids_l)
                xu, yu, label_flag_u = train_loader.fetch_batch(batch_ids_u)

                assert np.all(label_flag_l), "'label_flag_l: {}'".format(label_flag_l)
                assert not np.any(label_flag_u), "'label_flag_u: {}'".format(label_flag_u)

                x = np.concatenate([xl, xu], axis=0)
                y = np.concatenate([yl, yu], axis=0)
                label_flag = np.concatenate([label_flag_l, label_flag_u], axis=0)
            else:
                # print("Sample jointly!")
                batch_ids = sampler.sample_ids()
                x, y, label_flag = train_loader.fetch_batch(batch_ids)

            _, AEm = sess.run(
                [train_op_AE, model.get_output(fetch_keys_AE, as_dict=True)],
                feed_dict={
                    model.is_train: True,
                    model.x_ph: x,
                    model.y_ph: y,
                    model.label_flag_ph: label_flag
                })

            batch_results = AEm
            train_summarizer.accumulate(batch_results, args.batch_size)
            # ---------------------------------- #

            if global_step % args.save_freq == 0:
                train_helper.save(sess, global_step)

            if global_step % args.log_freq == 0:
                log_time_end = time()
                log_time_gap = (log_time_end - log_time_start)
                log_time_start = log_time_end

                summaries, results = train_summarizer.get_summaries_and_reset(
                    summary_prefix='train')
                train_helper.add_summaries(summaries, global_step)
                train_helper.add_summaries(
                    custom_tf_scalar_summaries(hyper_vals, prefix="moving_hyper"),
                    global_step)

                log_str = "\n[MeanTeacher ({})/{}, {}], " \
                    "Epoch {}/{}, Batch {}/{}, Step {} ({:.2f}s) (train)".format(
                        args.dataset, args.model_name, args.run,
                        epoch, args.epochs, batch, batches_per_epoch,
                        global_step - 1, log_time_gap) + \
                    "\n" + ", ".join(["{}: {:.4f}".format(key, results[key])
                                      for key in fetch_keys_AE_l]) + \
                    "\n" + ", ".join(["{}: {:.4f}".format(key, results[key])
                                      for key in fetch_keys_AE_u]) + \
                    "\n" + ", ".join(["{}: {:.4f}".format(key, hyper_vals[key])
                                      for key in hyper_vals])
                print(log_str)

                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")

            if global_step % args.eval_freq == 0:
                for batch_ids in iterate_data(test_loader.num_data,
                                              args.batch_size, shuffle=False,
                                              include_remaining=True):
                    x, y = test_loader.fetch_batch(batch_ids)

                    batch_results = sess.run(
                        model.get_output(fetch_keys_test, as_dict=True),
                        feed_dict={
                            model.is_train: False,
                            model.x_ph: x,
                            model.y_ph: y
                        })
                    eval_summarizer.accumulate(batch_results, len(batch_ids))

                summaries, results = eval_summarizer.get_summaries_and_reset(
                    summary_prefix='test')
                train_helper.add_summaries(summaries, global_step)

                log_str = "Epoch {}/{}, Batch {}/{} (test), " \
                    "acc_y: {:.4f}, acc_y_stu: {:.4f}".format(
                        epoch, args.epochs, batch, batches_per_epoch,
                        results['acc_y'], results['acc_y_stu'])
                print(log_str)

                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")

                is_better = tracker.check_and_update(results, global_step)
                if is_better['acc_y']:
                    train_helper.save_best(sess, global_step=global_step)

    # Last save
    train_helper.save(sess, global_step)
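
# NOTE: 'StepAnnealing', 'SigmoidRampup' and 'SigmoidRampdown' used in main()
# are defined elsewhere in this project. As a rough illustration only (the
# sketches below follow the exponential-sigmoid ramps commonly used in Mean
# Teacher implementations; the actual classes may differ), the values they
# produce could be computed as:
def _sigmoid_rampup_sketch(step, rampup_length):
    """Ramp from 0 to 1 over the first `rampup_length` steps: exp(-5 * (1 - t)^2)."""
    if rampup_length == 0:
        return 1.0
    t = np.clip(float(step) / rampup_length, 0.0, 1.0)
    return float(np.exp(-5.0 * (1.0 - t) ** 2))


def _sigmoid_rampdown_sketch(step, rampdown_length, total_steps):
    """Ramp from 1 to 0 over the last `rampdown_length` steps: exp(-12.5 * t^2)."""
    if rampdown_length == 0:
        return 1.0
    t = np.clip(1.0 - float(total_steps - step) / rampdown_length, 0.0, 1.0)
    return float(np.exp(-12.5 * t ** 2))


# For example, with these sketches the learning rate schedule in main() would be
# roughly:
#     lr = _sigmoid_rampup_sketch(step, args.rampup_len_step) * \
#          _sigmoid_rampdown_sketch(step, args.rampdown_len_step, args.steps) * \
#          args.lr_max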