Exemple #1
0
def zeroshot_train(t_depth,
                   t_width,
                   t_wght_path,
                   s_depth,
                   s_width,
                   seed=42,
                   savedir=None,
                   dataset='cifar10',
                   sample_per_class=0):
    """Train a student network against a frozen teacher on generated images.

    Each outer iteration draws one latent batch, runs ``Config.n_g_in_loop``
    generator updates (``train_gen``), then ``Config.n_s_in_loop`` student
    updates (``train_student``) on the images produced for that same latent
    batch. When ``sample_per_class > 0`` one additional student update on a
    batch of real, augmented training images is performed per outer iteration.

    Side effects: creates ``<cwd>/<savedir>``, writes a CSV training log there,
    keeps TF checkpoints plus an ``iteration`` counter file under
    ``<savedir>/chpt`` (and resumes from them when present), and saves
    generator/student weights as ``.h5`` files every 5000 iterations and at the
    last iteration.

    Args:
        t_depth: Depth passed to ``WideResidualNetwork`` for the teacher.
        t_width: Width passed to ``WideResidualNetwork`` for the teacher.
        t_wght_path: Path handed to ``teacher.load_weights``.
        s_depth: Depth passed to ``WideResidualNetwork`` for the student.
        s_width: Width passed to ``WideResidualNetwork`` for the student.
        seed: Value passed to ``set_seed``; also part of the run name.
        savedir: Output directory (joined onto the current working directory).
            Defaults to ``'zeroshot_' + <run name>``.
        dataset: ``'cifar10'`` or ``'fashion_mnist'``.
        sample_per_class: If > 0, number of real training samples per class
            selected with ``balance_sampling`` for the extra supervised step.

    Raises:
        ValueError: If ``dataset`` is not one of the two supported names.
    """
    set_seed(seed)

    train_name = '%s_T-%d-%d_S-%d-%d_seed_%d' % (dataset, t_depth, t_width,
                                                 s_depth, s_width, seed)
    if sample_per_class > 0:
        train_name += "-m%d" % sample_per_class
    log_filename = train_name + '_training_log.csv'

    # save dir
    if not savedir:
        savedir = 'zeroshot_' + train_name
    full_savedir = os.path.join(os.getcwd(), savedir)
    mkdir(full_savedir)

    log_filepath = os.path.join(full_savedir, log_filename)

    # Teacher: weights are loaded from disk and never updated.
    teacher = WideResidualNetwork(t_depth,
                                  t_width,
                                  input_shape=Config.input_dim,
                                  dropout_rate=0.0,
                                  output_activations=True,
                                  has_softmax=False)
    teacher.load_weights(t_wght_path)
    teacher.trainable = False

    # Student
    student = WideResidualNetwork(s_depth,
                                  s_width,
                                  input_shape=Config.input_dim,
                                  dropout_rate=0.0,
                                  output_activations=True,
                                  has_softmax=False)

    # The student takes one extra optimizer step per outer iteration when real
    # samples are used, so its schedule has to span that many more steps.
    s_steps_per_iter = Config.n_s_in_loop
    if sample_per_class > 0:
        s_steps_per_iter += 1
    s_decay_steps = Config.n_outer_loop * s_steps_per_iter

    s_optim = Adam(learning_rate=CosineDecay(Config.student_init_lr,
                                             decay_steps=s_decay_steps))
    # ---------------------------------------------------------------------------
    # Generator
    generator = NavieGenerator(input_dim=Config.z_dim)
    g_optim = Adam(learning_rate=CosineDecay(Config.generator_init_lr,
                                             decay_steps=Config.n_outer_loop *
                                             Config.n_g_in_loop))
    # ---------------------------------------------------------------------------
    # Data
    if dataset == 'cifar10':
        (x_train, y_train_lbl), (x_test, y_test) = get_cifar10_data()
    elif dataset == 'fashion_mnist':
        (x_train, y_train_lbl), (x_test, y_test) = get_fashion_mnist_data()
    else:
        raise ValueError("Only Cifar-10 and Fashion-MNIST supported !!")
    test_data_loader = tf.data.Dataset.from_tensor_slices(
        (x_test, y_test)).batch(200)
    # ---------------------------------------------------------------------------
    # Train data (only when real samples are requested)
    train_dataflow = None
    if sample_per_class > 0:
        # sample first
        x_train, y_train_lbl = \
            balance_sampling(x_train, y_train_lbl, data_per_class=sample_per_class)
        datagen = ImageDataGenerator(width_shift_range=4,
                                     height_shift_range=4,
                                     horizontal_flip=True,
                                     vertical_flip=False,
                                     rescale=None,
                                     fill_mode='reflect')
        datagen.fit(x_train)
        y_train = to_categorical(y_train_lbl)
        train_dataflow = datagen.flow(x_train,
                                      y_train,
                                      batch_size=Config.batch_size,
                                      shuffle=True)

    # Running means reported in the log.
    g_loss_met = tf.keras.metrics.Mean()
    s_loss_met = tf.keras.metrics.Mean()

    # Number of distinct classes predicted by teacher / student per batch.
    n_cls_t_pred_metric = tf.keras.metrics.Mean()
    n_cls_s_pred_metric = tf.keras.metrics.Mean()

    max_g_grad_norm_metric = tf.keras.metrics.Mean()
    max_s_grad_norm_metric = tf.keras.metrics.Mean()

    # checkpoint
    chkpt_dict = {
        'teacher': teacher,
        'student': student,
        'generator': generator,
        's_optim': s_optim,
        'g_optim': g_optim,
    }
    ckpt_dir = os.path.join(savedir, 'chpt')
    iteration_file = os.path.join(ckpt_dir, 'iteration')
    ckpt = tf.train.Checkpoint(**chkpt_dict)
    ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=2)
    # ==========================================================================
    # if a checkpoint exists, restore the latest checkpoint.
    start_iter = 0
    resumed = bool(ckpt_manager.latest_checkpoint)
    if resumed:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print('Latest checkpoint restored!!')
        with open(iteration_file, 'r') as f:
            start_iter = int(f.read())

    # The logger is created exactly once, after we know whether the run is a
    # resume, so an existing log is only ever opened in append mode.
    if resumed:
        logger = CustomizedCSVLogger(log_filepath, append=True)
    else:
        logger = CustomizedCSVLogger(log_filepath)

    weight_save_freq = 5000

    for iter_ in range(start_iter, Config.n_outer_loop):
        iter_stime = time.time()

        max_s_grad_norm = 0
        max_g_grad_norm = 0
        # sample from latent space to have an image
        z_val = tf.random.normal([Config.batch_size, Config.z_dim])

        # Generator training
        for _ in range(Config.n_g_in_loop):
            loss, g_grad_norm = train_gen(generator, g_optim, z_val, teacher,
                                          student)
            max_g_grad_norm = max(max_g_grad_norm, g_grad_norm.numpy())
            g_loss_met(loss)

        # ==========================================================================
        # Student training: the pseudo images and teacher outputs are computed
        # once and reused for every student step of this outer iteration.
        pseudo_imgs, t_logits, t_acts = prepare_train_student(
            generator, z_val, teacher)
        # t_logits does not change inside the loop, so count its classes once.
        n_cls_t_pred = len(np.unique(np.argmax(t_logits, axis=-1)))
        for _ in range(Config.n_s_in_loop):
            loss, s_grad_norm, s_logits = train_student(
                pseudo_imgs, s_optim, t_logits, t_acts, student)
            max_s_grad_norm = max(max_s_grad_norm, s_grad_norm.numpy())

            n_cls_s_pred = len(np.unique(np.argmax(s_logits, axis=-1)))
            # logging
            s_loss_met(loss)
            n_cls_t_pred_metric(n_cls_t_pred)
            n_cls_s_pred_metric(n_cls_s_pred)
        # ==========================================================================
        # one supervised step on real data, if provided n samples
        if train_dataflow:
            x_batch_train, y_batch_train = next(train_dataflow)
            real_t_logits, real_t_acts = forward(teacher, x_batch_train,
                                                 training=False)
            train_student_with_labels(student, s_optim, x_batch_train,
                                      real_t_logits, real_t_acts,
                                      y_batch_train)
        # ==========================================================================

        iter_etime = time.time()
        max_g_grad_norm_metric(max_g_grad_norm)
        max_s_grad_norm_metric(max_s_grad_norm)
        # --------------------------------------------------------------------
        is_last_epoch = (iter_ == Config.n_outer_loop - 1)
        should_print = iter_ != 0 and (iter_ % Config.print_freq == 0
                                       or is_last_epoch)
        should_log = iter_ != 0 and (iter_ % Config.log_freq == 0
                                     or is_last_epoch)

        # Build the row whenever it is needed by either branch; previously it
        # was only built when printing, so logging on an iteration that was not
        # also a print iteration hit an undefined or stale row.
        if should_print or should_log:
            row_dict = OrderedDict()

            row_dict['time_per_epoch'] = iter_etime - iter_stime
            row_dict['epoch'] = iter_
            row_dict['generator_loss'] = g_loss_met.result().numpy()
            row_dict['student_kd_loss'] = s_loss_met.result().numpy()
            row_dict['n_cls_t_pred_avg'] = n_cls_t_pred_metric.result().numpy()
            row_dict['n_cls_s_pred_avg'] = n_cls_s_pred_metric.result().numpy()
            row_dict['max_g_grad_norm_avg'] = \
                max_g_grad_norm_metric.result().numpy()
            row_dict['max_s_grad_norm_avg'] = \
                max_s_grad_norm_metric.result().numpy()

            # Evaluate each schedule at the number of optimizer steps taken
            # before this outer iteration. This assumes one optimizer step per
            # train_gen / train_student call, matching the decay_steps above.
            row_dict['s_optim_lr'] = s_optim.learning_rate(
                iter_ * s_steps_per_iter).numpy()
            row_dict['g_optim_lr'] = g_optim.learning_rate(
                iter_ * Config.n_g_in_loop).numpy()

        if should_print:
            pprint.pprint(row_dict)
        # ======================================================================
        if should_log:
            # calculate acc
            test_accuracy = evaluate(test_data_loader, student).numpy()
            row_dict['test_acc'] = test_accuracy
            logger.log_with_order(row_dict)
            print('Test Accuracy: ', test_accuracy)

            # checkpoint; the counter file stores the next iteration to run
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(
                iter_ + 1, ckpt_save_path))
            with open(iteration_file, 'w') as f:
                f.write(str(iter_ + 1))

            # NOTE(review): the two n_cls_*_pred metrics are not reset here, so
            # they average over the whole run rather than the logging window -
            # confirm whether that is intended.
            s_loss_met.reset_states()
            g_loss_met.reset_states()
            max_g_grad_norm_metric.reset_states()
            max_s_grad_norm_metric.reset_states()

        if iter_ != 0 and (iter_ % weight_save_freq == 0 or is_last_epoch):
            generator.save_weights(
                os.path.join(full_savedir, "generator_i{}.h5".format(iter_)))
            student.save_weights(
                os.path.join(full_savedir, "student_i{}.h5".format(iter_)))
Exemple #2
0
    def customTrainLoop(self):
        """
        Train the model with a hand-written loop and evaluate it after every epoch.

        Trainable variables whose name contains 'sb_t_u_1', 'sb_t_u_2' or 'sb_t_pi'
        are handed to a separate Adam optimizer (10x base learning rate); all other
        variables use the main optimizer. The label matrices are cut into
        params_dict['num_chunks'] equal-width column chunks and stacked along a new
        last axis, so accuracies are tracked per chunk. After each epoch the weights
        are saved through WeightsSaver (plus a '_best' copy whenever the mean
        validation accuracy improves), MaskSummary is invoked and self.saveModel()
        is called.

        @return: None
        """

        epochs = self.params_dict['epochs']
        batch_size = self.params_dict['batch_size']
        lr = self.params_dict['lr']

        # NOTE(review): 50000 looks like a hard-coded training-set size - confirm
        # it matches data_dict['X_train'] for every dataset this is used with.
        decay_steps = epochs*50000/batch_size

        # Schedules are evaluated at the optimizer step on every update.
        learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(lr, decay_steps, 1e-5, .5)
        opt = Adam(learning_rate=learning_rate)

        learning_rate_ibp = tf.keras.optimizers.schedules.PolynomialDecay(10*lr, decay_steps, 1e-4, .5)
        opt_ibp = Adam(learning_rate=learning_rate_ibp)

        # distinguish between the parameters of the ibp and the weights
        ibp_tags = ('sb_t_u_1', 'sb_t_u_2', 'sb_t_pi')
        vars_ibp = []
        vars_else = []
        for var in self.model.trainable_variables:
            if any(tag in var.name for tag in ibp_tags):
                vars_ibp.append(var)
            else:
                vars_else.append(var)

        # NOTE(review): train and validation share one metric object - confirm
        # train_step / test_step reset it as needed.
        train_acc_metric = val_acc_metric = self.defineMetric()[0]

        loss_fn = self.defineLoss(0.)

        # initialize the callbacks
        WS = WeightsSaver(self.params_dict['weight_save_freq'])
        WS.specifyFilePath(self.params_dict['model_path'] + self.params_dict['name'])
        MS = MaskSummary(1)

        # this is for the ensemble models: one label slice per chunk
        num_chunks = self.params_dict['num_chunks']
        chunk_width = int(self.params_dict['M'].shape[1] / num_chunks)
        Y_train_list = []
        Y_val_list = []
        start = 0
        for _ in np.arange(num_chunks):
            end = start + chunk_width
            Y_train_list.append(self.Y_train[:, start:end])
            Y_val_list.append(self.Y_test[:, start:end])
            start = end

        Y_train_list = np.dstack(Y_train_list)
        Y_val_list = np.dstack(Y_val_list)

        # Prepare the training dataset.
        train_dataset = tf.data.Dataset.from_tensor_slices((self.data_dict['X_train'], Y_train_list))
        train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

        # Prepare the validation dataset.
        val_dataset = tf.data.Dataset.from_tensor_slices((self.data_dict['X_test'], Y_val_list))
        val_dataset = val_dataset.batch(batch_size)

        # .batch() keeps the final partial batch, so the number of batches is the
        # ceiling of N / batch_size. Floor division under-counted them (inflating
        # the averaged accuracy) and was zero when N < batch_size.
        num_batches = -(-self.data_dict['X_train'].shape[0] // batch_size)
        num_batches_val = -(-self.data_dict['X_test'].shape[0] // batch_size)

        best_val_acc = 0.

        # Wrap the train step once per call of this method; it is re-wrapped
        # below if the traced graph becomes unusable.
        train = tf.function(train_step)

        # iterate over epochs
        # maybe add an early stopping criteria later
        for epoch in range(epochs):
            print("\nStart of epoch %d" % (epoch,))
            # NOTE(review): the ibp schedule is evaluated at opt.iterations, not
            # opt_ibp.iterations - confirm both optimizers step in lockstep.
            print('Learning Rates: weights', str(opt.learning_rate(opt.iterations).numpy()),
                  'ibp', str(opt_ibp.learning_rate(opt.iterations).numpy()))

            start_time = time.time()

            # Per-chunk running average of the batch accuracies.
            train_accs_all = [0.]*Y_train_list.shape[-1]

            for x_batch_train, y_batch_train in train_dataset:
                try:
                    loss_value, train_accs = train(self.model, loss_fn, opt, opt_ibp,
                                                   vars_else, vars_ibp,
                                                   train_acc_metric,
                                                   x_batch_train, y_batch_train)
                except ValueError:
                    # recreate the decorated function: the old graph is dropped
                    # and a new workable graph is traced on this call
                    train = tf.function(train_step)
                    loss_value, train_accs = train(self.model, loss_fn, opt, opt_ibp,
                                                   vars_else, vars_ibp,
                                                   train_acc_metric,
                                                   x_batch_train, y_batch_train)
                train_accs_all = [total + train_accs[i]/num_batches
                                  for i, total in enumerate(train_accs_all)]

            # Display metrics at the end of each epoch.
            print("Training acc over epoch: ", end = '')
            for i, acc in enumerate(train_accs_all):
                if len(train_accs_all)>1:
                    print('chunk %d acc: %.4f' % (i, float(acc)), end = '')
                else:
                    print('%.4f' % (float(acc)), end = '')
            print()

            # Run a validation loop at the end of each epoch.
            # NOTE(review): test_step is called directly here; the tf.function
            # wrapper is only built after a failure - confirm this is intended.
            val_accs_all = [0.]*Y_val_list.shape[-1]
            for x_batch_val, y_batch_val in val_dataset:
                try:
                    val_accs = test_step(self.model, val_acc_metric, x_batch_val, y_batch_val)
                except (UnboundLocalError, ValueError):
                    test = tf.function(test_step)
                    val_accs = test(self.model, val_acc_metric, x_batch_val, y_batch_val)
                val_accs_all = [total + val_accs[i]/num_batches_val
                                for i, total in enumerate(val_accs_all)]

            print("Validation acc: ", end = '')
            for i, acc in enumerate(val_accs_all):
                if len(val_accs_all)>1:
                    print('chunk %d acc %.4f' % (i, float(acc),), end = '')
                else:
                    print('%.4f' % (acc,), end = '')
            print()

            # keep the best model according to the mean validation accuracy,
            # saved under a '_best' suffix, then restore the regular save path
            mean_val_acc = float(np.mean(val_accs_all))
            if mean_val_acc > best_val_acc:
                best_val_acc = mean_val_acc
                WS.specifyFilePath(self.params_dict['model_path'] + self.params_dict['name']+'_best')
                WS.on_epoch_end(self.model, 0)
                WS.specifyFilePath(self.params_dict['model_path'] + self.params_dict['name'])

            # save the models, print the masks
            WS.on_epoch_end(self.model, 0)
            MS.on_epoch_end(self.model, 1)
            print("Time taken: %.2fs" % (time.time() - start_time))

            print()
            self.saveModel()