def test_model_with_set_get_weights(
    model: nn.Module,
    testset: Dataset,
    metric: TrainingMetrics,
    config: TrainingConfig,
    rept: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """Measure how much test metrics change across a get_params/set_params round trip.

    Each repetition wraps ``model`` in a new ``NeuralTeleportationModel``, snapshots
    its (weights, cob) with ``get_params``, applies ``random_teleport`` and snapshots
    the teleported (weights, cob). The wrapped model is then evaluated with ``test``
    twice: once after restoring the original snapshot with ``set_params``, and once
    after restoring the teleported snapshot. The absolute loss and accuracy
    differences of each repetition are printed.

    Args:
        model: network to wrap; the wrapper is built with input shape
            ``(config.batch_size, 3, 32, 32)``.
        testset: dataset passed to ``test``.
        metric: metrics object passed to ``test``.
        config: training configuration passed to ``test``; ``config.batch_size``
            is also used for the wrapper's input shape.
        rept: number of independent repetitions.

    Returns:
        ``(mean |loss diff|, mean |accuracy diff|)`` over the ``rept`` repetitions.
    """
    loss_gaps = []
    acc_gaps = []

    for _ in range(rept):
        wrapped = NeuralTeleportationModel(
            model, input_shape=(config.batch_size, 3, 32, 32)
        ).to(device)

        # Snapshot the parameters before and after a random teleportation.
        orig_weights, orig_cob = wrapped.get_params()
        wrapped.random_teleport()
        tele_weights, tele_cob = wrapped.get_params()

        # Evaluate with the original snapshot restored.
        wrapped.set_params(weights=orig_weights, cob=orig_cob)
        before = test(wrapped, testset, metric, config)
        loss_before, acc_before = before['loss'], before['accuracy']

        # Evaluate with the teleported snapshot restored.
        wrapped.set_params(weights=tele_weights, cob=tele_cob)
        after = test(wrapped, testset, metric, config)
        loss_after, acc_after = after['loss'], after['accuracy']

        loss_gap = np.abs(loss_before - loss_after)
        acc_gap = np.abs(acc_before - acc_after)
        loss_gaps.append(loss_gap)
        acc_gaps.append(acc_gap)

        print("==========================================")
        print("Loss and accuracy diff with set/get was")
        print("Loss diff was: {:.6e}".format(loss_gap))
        print("Acc diff was: {:.6e}".format(acc_gap))
        print("==========================================")

    return np.mean(loss_gaps), np.mean(acc_gaps)
def generate_1D_linear_interp(
    model: NeuralTeleportationModel,
    param_o: Tuple[torch.Tensor, torch.Tensor],
    param_t: Tuple[torch.Tensor, torch.Tensor],
    a: torch.Tensor,
    trainset: Dataset,
    valset: Dataset,
    metric: TrainingMetrics,
    config: TrainingConfig,
    checkpoint: dict = None,
) -> Tuple[list, list, list, list]:
    """
    This is 1-Dimensional Linear Interpolation θ(α) = (1−α)θ + αθ′

    Evaluate ``model`` along the straight line between the original parameters
    ``param_o`` and the teleported parameters ``param_t``. For every coefficient
    α in ``a``, both the weights and the change-of-basis (cob) vector are
    interpolated, loaded with ``model.set_params``, and the model is evaluated
    with ``test`` on ``trainset`` and then on ``valset``.

    Args:
        model: model whose parameters are overwritten at every step.
        param_o: ``(weights, cob)`` of the original model.
        param_t: ``(weights, cob)`` of the teleported model.
        a: interpolation coefficients α, visited in order.
        trainset: dataset for the training-set evaluation.
        valset: dataset for the validation-set evaluation.
        metric: forwarded to ``test``.
        config: forwarded to ``test``.
        checkpoint: a dict previously saved by this function. When given,
            evaluation resumes at ``a[checkpoint["step"]]``, and the results
            stored in the checkpoint are used as the prefix of the returned lists.
            The checkpoint is presumably produced for the same ``a`` (verify at
            the call site).

    Returns:
        ``(train_losses, train_accuracies, val_losses, val_accuracies)``.

    Raises:
        Re-raises any exception raised during evaluation (including
        KeyboardInterrupt), after saving a checkpoint with ``torch.save`` to
        ``linterp_checkpoint_file``.
    """
    w_o, cob_o = param_o
    w_t, cob_t = param_t

    if checkpoint:
        start_at = checkpoint["step"]
        # Continue from the results accumulated before the interruption. Copies
        # are used so the caller's dict is only modified when a new checkpoint
        # is written.
        loss = list(checkpoint.get('losses') or [])
        acc_t = list(checkpoint.get('acc_t') or [])
        acc_v = list(checkpoint.get('acc_v') or [])
        # 'loss_v' may be absent from checkpoints that did not record it.
        loss_v = list(checkpoint.get('loss_v') or [])
    else:
        start_at = 0
        loss, loss_v, acc_t, acc_v = [], [], [], []

    # Bound up front so the except-branch can report it even if the failure
    # happens before the first iteration.
    step = start_at
    try:
        # Skip the coefficients already evaluated; `step` stays an index into `a`.
        for step, coord in enumerate(a[start_at:], start_at):
            print("step {} of {} - alpha={}".format(step + 1, len(a), coord))
            # Interpolate the weights from W to T(W), and the cob alongside them.
            w = (1 - coord) * w_o + coord * w_t
            cob = (1 - coord) * cob_o + coord * cob_t
            model.set_params(w, cob)

            train_res = test(model, trainset, metric, config)
            val_res = test(model, valset, metric, config)
            # Record the step only after both evaluations succeeded, so an
            # interruption never leaves the lists disagreeing about which
            # steps are complete (the in-progress step is redone on resume).
            loss.append(train_res['loss'])
            acc_t.append(train_res['accuracy'])
            loss_v.append(val_res['loss'])
            acc_v.append(val_res['accuracy'])
    except BaseException:
        # Deliberately broad (like the original bare ``except``): an interrupted
        # run, e.g. via KeyboardInterrupt, should also be checkpointed.
        if not checkpoint:
            checkpoint = {
                'alpha': a,
                'original_model': param_o,
                'teleported_model': param_t,
            }
        # The lists already hold every completed step (including those loaded
        # from a previous checkpoint), so they replace the stored ones.
        checkpoint['step'] = step
        checkpoint['losses'] = loss
        checkpoint['acc_t'] = acc_t
        checkpoint['acc_v'] = acc_v
        checkpoint['loss_v'] = loss_v
        torch.save(checkpoint, linterp_checkpoint_file)
        print("A checkpoint was made on step {} of {}".format(step, len(a)))
        # This is to notify the upper level of try/except
        # Since there is no way to know if this is from before teleportation or after teleportation.
        raise
    return loss, acc_t, loss_v, acc_v