Example 1
Checks that teleporting a model and then restoring either the original or the teleported parameters with set_params reproduces the same test loss and accuracy.
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset

# NeuralTeleportationModel, test, TrainingMetrics, TrainingConfig and the
# global `device` are assumed to come from the surrounding
# neuralteleportation codebase.
def test_model_with_set_get_weights(
        model: nn.Module,
        testset: Dataset,
        metric: TrainingMetrics,
        config: TrainingConfig,
        rept: int = 1) -> Tuple[np.ndarray, np.ndarray]:
    loss_diff_avg = []
    acc_diff_avg = []
    for _ in range(rept):
        m = NeuralTeleportationModel(model,
                                     input_shape=(config.batch_size, 3, 32,
                                                  32)).to(device)
        w_o, cob_o = m.get_params()
        m.random_teleport()
        w_t, cob_t = m.get_params()

        m.set_params(weights=w_o, cob=cob_o)
        res = test(m, testset, metric, config)
        loss1, acc1 = res['loss'], res['accuracy']

        m.set_params(weights=w_t, cob=cob_t)
        res = test(m, testset, metric, config)
        loss2, acc2 = res['loss'], res['accuracy']

        loss_diff_avg.append(np.abs(loss1 - loss2))
        acc_diff_avg.append(np.abs(acc1 - acc2))

        print("==========================================")
        print("Loss and accuracy diff with set/get was")
        print("Loss diff was: {:.6e}".format(np.abs(loss1 - loss2)))
        print("Acc diff was: {:.6e}".format(np.abs(acc1 - acc2)))
        print("==========================================")

    return np.mean(loss_diff_avg), np.mean(acc_diff_avg)
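
A quick way to see the invariant this test relies on, in isolation: the self-contained pure-PyTorch sketch below uses a plain state_dict snapshot/restore on a toy MLP as a stand-in for the library's get_params/set_params, so the network and the restore mechanism here are illustrative assumptions, not the code above.

import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
x = torch.randn(4, 8)

# Snapshot the parameters, scramble them, then restore the snapshot.
saved = {k: v.clone() for k, v in net.state_dict().items()}
out_before = net(x)
with torch.no_grad():
    for p in net.parameters():
        p.add_(torch.randn_like(p))  # stand-in for changing the weights
net.load_state_dict(saved)           # exact restore of the snapshot
out_after = net(x)

# An exact set/get round-trip implies identical outputs, hence identical
# loss and accuracy, which is what the averaged diffs above measure.
print(torch.allclose(out_before, out_after))  # True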
Example 2
Sweeps α from the original parameters to their teleported counterparts, recording train/validation loss and accuracy at each interpolated point and checkpointing progress on failure.
def generate_1D_linear_interp(
        model: NeuralTeleportationModel,
        param_o: Tuple[torch.Tensor, torch.Tensor],
        param_t: Tuple[torch.Tensor, torch.Tensor],
        a: torch.Tensor,
        trainset: Dataset,
        valset: Dataset,
        metric: TrainingMetrics,
        config: TrainingConfig,
        checkpoint: dict = None) -> Tuple[list, list, list, list]:
    """
        This is 1-Dimensional Linear Interpolation
        θ(α) = (1−α)θ + αθ′
    """
    loss = []
    loss_v = []
    acc_t = []
    acc_v = []
    w_o, cob_o = param_o
    w_t, cob_t = param_t
    start_at = checkpoint["step"] if checkpoint else 0
    step = start_at  # keep `step` bound even if the loop raises immediately
    try:
        # Resume after the checkpointed step instead of re-running every alpha.
        for step, coord in enumerate(a[start_at:], start=start_at):
            # Interpolate the weight from W to T(W),
            # then interpolate the cob for the activation
            # and batchNorm layers only.
            print("step {} of {} - alpha={}".format(step + 1, len(a), coord))
            w = (1 - coord) * w_o + coord * w_t
            cob = (1 - coord) * cob_o + coord * cob_t
            model.set_params(w, cob)
            res = test(model, trainset, metric, config)
            loss.append(res['loss'])
            acc_t.append(res['accuracy'])
            res = test(model, valset, metric, config)
            acc_v.append(res['accuracy'])
            loss_v.append(res['loss'])
    except BaseException:  # equivalent to the bare except; also traps KeyboardInterrupt
        if not checkpoint:
            checkpoint = {
                'step': step,
                'alpha': a,
                'original_model': param_o,
                'teleported_model': param_t,
                'losses': loss,
                'loss_v': loss_v,
                'acc_t': acc_t,
                'acc_v': acc_v,
            }
        else:
            # list.append() returns None and would overwrite the saved
            # histories, so extend them in place instead.
            checkpoint['step'] = step
            checkpoint['losses'].extend(loss)
            checkpoint['loss_v'].extend(loss_v)
            checkpoint['acc_t'].extend(acc_t)
            checkpoint['acc_v'].extend(acc_v)  # the original wrongly appended `loss` here
        torch.save(checkpoint, linterp_checkpoint_file)
        print("A checkpoint was saved at step {} of {}".format(step, len(a)))
        # Re-raise to notify the caller's try/except, since from here there
        # is no way to tell whether the failure happened before or after
        # teleportation.
        raise

    return loss, acc_t, loss_v, acc_v
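
A hedged driver sketch for the helper above, including resumption from its checkpoint. It reuses the get_params/random_teleport calls seen in Example 1; `model`, `trainset`, `valset`, `metric`, `config` and `linterp_checkpoint_file` are assumed to be set up elsewhere in the codebase.

import torch

param_o = model.get_params()   # original weights and change of basis
model.random_teleport()
param_t = model.get_params()   # teleported weights and change of basis
alphas = torch.linspace(0, 1, steps=25)

try:
    loss_t, acc_t, loss_v, acc_v = generate_1D_linear_interp(
        model, param_o, param_t, alphas, trainset, valset, metric, config)
except BaseException:
    # Resume from the checkpoint written inside generate_1D_linear_interp.
    ckpt = torch.load(linterp_checkpoint_file)
    loss_t, acc_t, loss_v, acc_v = generate_1D_linear_interp(
        model, param_o, param_t, alphas, trainset, valset, metric, config,
        checkpoint=ckpt)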