import os
import time
import pprint
from collections import OrderedDict
from os.path import join

import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import CosineDecay
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

# Project-local helpers (set_seed, mkdir, CustomizedCSVLogger,
# WideResidualNetwork, NavieGenerator, Config, the data loaders and the
# train/evaluate steps) are assumed to be importable from the surrounding
# package.


def zeroshot_train(t_depth, t_width, t_wght_path, s_depth, s_width,
                   seed=42, savedir=None, dataset='cifar10',
                   sample_per_class=0):
    set_seed(seed)

    train_name = '%s_T-%d-%d_S-%d-%d_seed_%d' % (
        dataset, t_depth, t_width, s_depth, s_width, seed)
    if sample_per_class > 0:
        train_name += "-m%d" % sample_per_class
    log_filename = train_name + '_training_log.csv'

    # save dir
    if not savedir:
        savedir = 'zeroshot_' + train_name
    full_savedir = os.path.join(os.getcwd(), savedir)
    mkdir(full_savedir)

    log_filepath = os.path.join(full_savedir, log_filename)
    logger = CustomizedCSVLogger(log_filepath)

    # Teacher: pretrained and frozen for the whole run
    teacher = WideResidualNetwork(t_depth, t_width,
                                  input_shape=Config.input_dim,
                                  dropout_rate=0.0,
                                  output_activations=True,
                                  has_softmax=False)
    teacher.load_weights(t_wght_path)
    teacher.trainable = False

    # Student
    student = WideResidualNetwork(s_depth, s_width,
                                  input_shape=Config.input_dim,
                                  dropout_rate=0.0,
                                  output_activations=True,
                                  has_softmax=False)

    # one extra student step per outer loop when labelled samples are used
    if sample_per_class > 0:
        s_decay_steps = Config.n_outer_loop * Config.n_s_in_loop + Config.n_outer_loop
    else:
        s_decay_steps = Config.n_outer_loop * Config.n_s_in_loop

    s_optim = Adam(learning_rate=CosineDecay(Config.student_init_lr,
                                             decay_steps=s_decay_steps))

    # -----------------------------------------------------------------
    # Generator
    generator = NavieGenerator(input_dim=Config.z_dim)
    g_optim = Adam(learning_rate=CosineDecay(
        Config.generator_init_lr,
        decay_steps=Config.n_outer_loop * Config.n_g_in_loop))

    # -----------------------------------------------------------------
    # Test data
    if dataset == 'cifar10':
        (x_train, y_train_lbl), (x_test, y_test) = get_cifar10_data()
    elif dataset == 'fashion_mnist':
        (x_train, y_train_lbl), (x_test, y_test) = get_fashion_mnist_data()
    else:
        raise ValueError("Only CIFAR-10 and Fashion-MNIST are supported")

    test_data_loader = tf.data.Dataset.from_tensor_slices(
        (x_test, y_test)).batch(200)

    # -----------------------------------------------------------------
    # Train data (only used when sample_per_class > 0)
    train_dataflow = None
    if sample_per_class > 0:
        # sample a class-balanced subset first
        x_train, y_train_lbl = balance_sampling(
            x_train, y_train_lbl, data_per_class=sample_per_class)
        datagen = ImageDataGenerator(width_shift_range=4,
                                     height_shift_range=4,
                                     horizontal_flip=True,
                                     vertical_flip=False,
                                     rescale=None,
                                     fill_mode='reflect')
        datagen.fit(x_train)
        y_train = to_categorical(y_train_lbl)
        train_dataflow = datagen.flow(x_train, y_train,
                                      batch_size=Config.batch_size,
                                      shuffle=True)

    # loss and diagnostic metrics, averaged between logging points
    g_loss_met = tf.keras.metrics.Mean()
    s_loss_met = tf.keras.metrics.Mean()
    n_cls_t_pred_metric = tf.keras.metrics.Mean()
    n_cls_s_pred_metric = tf.keras.metrics.Mean()
    max_g_grad_norm_metric = tf.keras.metrics.Mean()
    max_s_grad_norm_metric = tf.keras.metrics.Mean()

    # checkpoint
    chkpt_dict = {
        'teacher': teacher,
        'student': student,
        'generator': generator,
        's_optim': s_optim,
        'g_optim': g_optim,
    }
    ckpt = tf.train.Checkpoint(**chkpt_dict)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              os.path.join(savedir, 'chpt'),
                                              max_to_keep=2)

    # ==================================================================
    # If a checkpoint exists, restore the latest one.
    start_iter = 0
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print('Latest checkpoint restored!')
        with open(os.path.join(savedir, 'chpt', 'iteration'), 'r') as f:
            start_iter = int(f.read())
        logger = CustomizedCSVLogger(log_filepath, append=True)

    for iter_ in range(start_iter, Config.n_outer_loop):
        iter_stime = time.time()

        max_s_grad_norm = 0
        max_g_grad_norm = 0

        # sample from the latent space
        z_val = tf.random.normal([Config.batch_size, Config.z_dim])

        # Generator training
        loss = 0
        for ng in range(Config.n_g_in_loop):
            loss, g_grad_norm = train_gen(generator, g_optim, z_val,
                                          teacher, student)
            max_g_grad_norm = max(max_g_grad_norm, g_grad_norm.numpy())
            g_loss_met(loss)

        # ==============================================================
        # Student training: the pseudo images are generated once per
        # outer iteration and reused for all inner student steps
        loss = 0
        pseudo_imgs, t_logits, t_acts = prepare_train_student(
            generator, z_val, teacher)
        for ns in range(Config.n_s_in_loop):
            loss, s_grad_norm, s_logits = train_student(
                pseudo_imgs, s_optim, t_logits, t_acts, student)
            max_s_grad_norm = max(max_s_grad_norm, s_grad_norm.numpy())

            # number of distinct classes predicted on the pseudo batch
            n_cls_t_pred = len(np.unique(np.argmax(t_logits, axis=-1)))
            n_cls_s_pred = len(np.unique(np.argmax(s_logits, axis=-1)))

            # logging
            s_loss_met(loss)
            n_cls_t_pred_metric(n_cls_t_pred)
            n_cls_s_pred_metric(n_cls_s_pred)

        # ==============================================================
        # extra supervised step if n labelled samples are provided
        if train_dataflow:
            x_batch_train, y_batch_train = next(train_dataflow)
            t_logits, t_acts = forward(teacher, x_batch_train, training=False)
            loss = train_student_with_labels(student, s_optim, x_batch_train,
                                             t_logits, t_acts, y_batch_train)

        # --------------------------------------------------------------
        iter_etime = time.time()
        max_g_grad_norm_metric(max_g_grad_norm)
        max_s_grad_norm_metric(max_s_grad_norm)

        # --------------------------------------------------------------
        is_last_epoch = (iter_ == Config.n_outer_loop - 1)

        if iter_ != 0 and (iter_ % Config.print_freq == 0 or is_last_epoch):
            n_cls_t_pred_avg = n_cls_t_pred_metric.result().numpy()
            n_cls_s_pred_avg = n_cls_s_pred_metric.result().numpy()
            time_per_epoch = iter_etime - iter_stime

            s_loss = s_loss_met.result().numpy()
            g_loss = g_loss_met.result().numpy()
            max_g_grad_norm_avg = max_g_grad_norm_metric.result().numpy()
            max_s_grad_norm_avg = max_s_grad_norm_metric.result().numpy()

            # build an ordered dict for the CSV row
            row_dict = OrderedDict()
            row_dict['time_per_epoch'] = time_per_epoch
            row_dict['epoch'] = iter_
            row_dict['generator_loss'] = g_loss
            row_dict['student_kd_loss'] = s_loss
            row_dict['n_cls_t_pred_avg'] = n_cls_t_pred_avg
            row_dict['n_cls_s_pred_avg'] = n_cls_s_pred_avg
            row_dict['max_g_grad_norm_avg'] = max_g_grad_norm_avg
            row_dict['max_s_grad_norm_avg'] = max_s_grad_norm_avg

            if sample_per_class > 0:
                s_optim_iter = iter_ * (Config.n_s_in_loop + 1)
            else:
                s_optim_iter = iter_ * Config.n_s_in_loop
            row_dict['s_optim_lr'] = s_optim.learning_rate(s_optim_iter).numpy()
            row_dict['g_optim_lr'] = g_optim.learning_rate(iter_).numpy()

            pprint.pprint(row_dict)

        # ==============================================================
        # NOTE: this block reuses row_dict from the block above, so
        # Config.log_freq is assumed to be a multiple of Config.print_freq
        if iter_ != 0 and (iter_ % Config.log_freq == 0 or is_last_epoch):
            # calculate test accuracy
            test_accuracy = evaluate(test_data_loader, student).numpy()
            row_dict['test_acc'] = test_accuracy
            logger.log_with_order(row_dict)
            print('Test Accuracy: ', test_accuracy)

            # checkpointing
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(
                iter_ + 1, ckpt_save_path))
            with open(os.path.join(savedir, 'chpt', 'iteration'), 'w') as f:
                f.write(str(iter_ + 1))

            s_loss_met.reset_states()
            g_loss_met.reset_states()
            max_g_grad_norm_metric.reset_states()
            max_s_grad_norm_metric.reset_states()

        if iter_ != 0 and (iter_ % 5000 == 0 or is_last_epoch):
            generator.save_weights(
                join(full_savedir, "generator_i{}.h5".format(iter_)))
            student.save_weights(
                join(full_savedir, "student_i{}.h5".format(iter_)))
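
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original script). It assumes a
# pretrained WRN-40-2 teacher checkpoint; the weight path is a hypothetical
# placeholder and the depth/width values are illustrative only.
if __name__ == '__main__':
    zeroshot_train(t_depth=40, t_width=2,
                   t_wght_path='cifar10_WRN-40-2.h5',  # hypothetical path
                   s_depth=16, s_width=1,
                   seed=42, dataset='cifar10', sample_per_class=0)
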
def customTrainLoop(self):
    """
    Main function for training and evaluating the model. It uses the
    functions defined previously to apply the gradients and to test on
    the validation set.
    """
    epochs = self.params_dict['epochs']
    batch_size = self.params_dict['batch_size']
    lr = self.params_dict['lr']

    # The schedules below are queried with the optimizer step whenever an
    # optimization step is performed.
    learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(
        lr, epochs * 50000 / batch_size, 1e-5, 0.5)
    opt = Adam(learning_rate=learning_rate)

    learning_rate_ibp = tf.keras.optimizers.schedules.PolynomialDecay(
        10 * lr, epochs * 50000 / batch_size, 1e-4, 0.5)
    opt_ibp = Adam(learning_rate=learning_rate_ibp)

    # distinguish between the IBP parameters and the ordinary weights
    vars_ibp = []
    vars_else = []
    for var in self.model.trainable_variables:
        name = var.name
        if 'sb_t_u_1' in name or 'sb_t_u_2' in name or 'sb_t_pi' in name:
            vars_ibp.append(var)
        else:
            vars_else.append(var)

    # the metric for the train and test sets
    train_acc_metric = val_acc_metric = self.defineMetric()[0]
    loss_fn = self.defineLoss(0.)  # e.g. CategoricalCrossentropy(from_logits=True)

    # initialize the callbacks
    WS = WeightsSaver(self.params_dict['weight_save_freq'])
    WS.specifyFilePath(self.params_dict['model_path'] + self.params_dict['name'])
    MS = MaskSummary(1)

    # split the labels into chunks for the ensemble models
    Y_train_list = []
    Y_val_list = []
    start = 0
    for k in np.arange(self.params_dict['num_chunks']):
        end = start + int(self.params_dict['M'].shape[1] /
                          self.params_dict['num_chunks'])
        Y_train_list += [self.Y_train[:, start:end]]
        Y_val_list += [self.Y_test[:, start:end]]
        start = end
    Y_train_list = np.dstack(Y_train_list)
    Y_val_list = np.dstack(Y_val_list)

    # Prepare the training dataset.
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (self.data_dict['X_train'], Y_train_list))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

    # Prepare the validation dataset.
    val_dataset = tf.data.Dataset.from_tensor_slices(
        (self.data_dict['X_test'], Y_val_list))
    val_dataset = val_dataset.batch(batch_size)

    best_val_acc = 0.
    num_batches = self.data_dict['X_train'].shape[0] // batch_size
    num_batches_val = self.data_dict['X_test'].shape[0] // batch_size

    # iterate over epochs; an early stopping criterion could be added later
    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch,))
        print('Learning Rates: weights',
              str(opt.learning_rate(opt.iterations).numpy()),
              'ibp', str(opt_ibp.learning_rate(opt_ibp.iterations).numpy()))
        start_time = time.time()

        # Iterate over the batches of the dataset.
        train_accs_all = [0.] * Y_train_list.shape[-1]
        for _, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            try:
                loss_value, train_accs = train(
                    self.model, loss_fn, opt, opt_ibp, vars_else, vars_ibp,
                    train_acc_metric, x_batch_train, y_batch_train)
            except (UnboundLocalError, ValueError):
                # Decorating train_step directly with @tf.function raises an
                # error on redefinition, so the decorated function is
                # recreated here: the old graph is discarded and a fresh,
                # workable graph is traced.
                train = tf.function(train_step)
                loss_value, train_accs = train(
                    self.model, loss_fn, opt, opt_ibp, vars_else, vars_ibp,
                    train_acc_metric, x_batch_train, y_batch_train)
            train_accs_all = [train_accs_all[i] + train_accs[i] / num_batches
                              for i in range(len(train_accs_all))]

        # Display metrics at the end of each epoch.
        # train_acc = train_acc_metric.result()
        print("Training acc over epoch: ", end='')
        for i in range(len(train_accs_all)):
            if len(train_accs_all) > 1:
                print('chunk %d acc: %.4f ' % (i, float(train_accs_all[i])),
                      end='')
            else:
                print('%.4f' % (float(train_accs_all[i])), end='')
        print()

        # Reset training metrics at the end of each epoch
        # train_acc_metric.reset_states()

        # Run a validation loop at the end of each epoch.
        val_accs_all = [0.] * Y_val_list.shape[-1]
        for x_batch_val, y_batch_val in val_dataset:
            try:
                # mirrors the training pattern above: the first call raises
                # UnboundLocalError and the step is (re)wrapped below
                val_accs = test(self.model, val_acc_metric,
                                x_batch_val, y_batch_val)
            except (UnboundLocalError, ValueError):
                test = tf.function(test_step)
                val_accs = test(self.model, val_acc_metric,
                                x_batch_val, y_batch_val)
            val_accs_all = [val_accs_all[i] + val_accs[i] / num_batches_val
                            for i in range(len(val_accs_all))]

        # val_acc = val_acc_metric.result()
        # val_acc_metric.reset_states()

        print("Validation acc: ", end='')
        for i in range(len(val_accs_all)):
            if len(val_accs_all) > 1:
                print('chunk %d acc: %.4f ' % (i, float(val_accs_all[i])),
                      end='')
            else:
                print('%.4f' % (val_accs_all[i],), end='')
        print()

        # keep the best model according to the validation accuracy,
        # just to see the difference
        if float(np.mean(val_accs_all)) > best_val_acc:
            best_val_acc = float(np.mean(val_accs_all))
            WS.specifyFilePath(self.params_dict['model_path'] +
                               self.params_dict['name'] + '_best')
            WS.on_epoch_end(self.model, 0)
            WS.specifyFilePath(self.params_dict['model_path'] +
                               self.params_dict['name'])

        # save the model, print the masks
        WS.on_epoch_end(self.model, 0)
        MS.on_epoch_end(self.model, 1)
        print("Time taken: %.2fs" % (time.time() - start_time))
        print()

    self.saveModel()
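
# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the train_step that the loop above
# re-wraps with tf.function. It illustrates the two-optimizer split (IBP
# parameters vs. ordinary weights) under the simplifying assumption of a
# single output chunk; the real train_step and its per-chunk accuracies are
# defined elsewhere in this repository, so treat this as illustrative only.
def train_step_sketch(model, loss_fn, opt, opt_ibp, vars_else, vars_ibp,
                      acc_metric, x_batch, y_batch):
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, logits)
    # one backward pass; route each gradient to its matching optimizer
    grads = tape.gradient(loss_value, vars_else + vars_ibp)
    opt.apply_gradients(zip(grads[:len(vars_else)], vars_else))
    opt_ibp.apply_gradients(zip(grads[len(vars_else):], vars_ibp))
    acc_metric.update_state(y_batch, logits)
    accs = [acc_metric.result()]
    acc_metric.reset_states()
    return loss_value, accs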