Example No. 1
    with open(
        os.path.join(DATAPATH, "label_num_to_disease_map.json"), "r"
    ) as f:
        labels = json.load(f)

    # Dataset
    # dataset = init_dataset(
    #     os.path.join(DATAPATH, "train_tfrecords"),
    #     is_target=True,
    #     shuffle=True,
    # )
    #
    # ds_train, ds_test = split_dataset(dataset, train_size=TRAIN_SIZE)

    ds_train = init_dataset(
        os.path.join(TFRECORDS_TRAIN_PATH),
        is_target=True,
        shuffle=True,
        augment=True,
    )
    ds_train = ds_train.map(
        input_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    ds_train = ds_train.batch(batch_size=BATCH_SIZE, drop_remainder=True)
    ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
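    # With AUTOTUNE, tf.data picks the map() parallelism and the prefetch
    # buffer size at runtime; prefetching after batch() keeps whole batches
    # ready ahead of each training step.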

    # ValueError: When providing an infinite dataset, you must specify the number
    # of steps to run (if you did not intend to create an infinite dataset,
    # make sure to not call `repeat()` on the dataset).
    # ds_train = ds_train.repeat()
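    # A minimal sketch of the repeat() alternative (assuming a Keras-style
    # model.fit and a hypothetical NUM_TRAIN_SAMPLES constant): an infinite
    # dataset is fine as long as the step count is given explicitly, e.g.
    #   ds_train = ds_train.repeat()
    #   model.fit(ds_train, steps_per_epoch=NUM_TRAIN_SAMPLES // BATCH_SIZE)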

    ds_test = init_dataset(
        os.path.join(TFRECORDS_VAL_PATH),

Example No. 2
def main(args):
    cmd_line = " ".join(sys.argv)
    log.info(f"{cmd_line}")
    log.info(f"Working dir: {os.getcwd()}")
    set_all_seeds(args.seed)

    ks = args.ks
    lrs = parse_lr(args.lr)

    if args.perturb is not None or args.same_dir or args.same_sign:
        args.node_multi = 1

    if args.load_student is not None:
        args.num_trial = 1

    d, d_output, train_dataset, eval_dataset = init_dataset(args)

    if args.total_bp_iters > 0 and isinstance(train_dataset, RandomDataset):
        args.num_epoch = args.total_bp_iters / args.random_dataset_size
        if args.num_epoch != int(args.num_epoch):
            raise RuntimeError(
                f"random_dataset_size [{args.random_dataset_size}] cannot devide total_bp_iters [{args.total_bp_iters}]"
            )

        args.num_epoch = int(args.num_epoch)
        log.info(f"#Epoch is now set to {args.num_epoch}")

    # ks = [5, 6, 7, 8]
    # ks = [10, 15, 20, 25]
    # ks = [50, 75, 100, 125]

    # ks = [50, 75, 100, 125]
    log.info(args.pretty())
    log.info(f"ks: {ks}")
    log.info(f"lr: {lrs}")

    if args.d_output > 0:
        d_output = args.d_output

    log.info(f"d_output: {d_output}")

    if not args.use_cnn:
        teacher = Model(d[0],
                        ks,
                        d_output,
                        has_bias=not args.no_bias,
                        has_bn=args.teacher_bn,
                        has_bn_affine=args.teacher_bn_affine,
                        bn_before_relu=args.bn_before_relu,
                        leaky_relu=args.leaky_relu).cuda()

    else:
        teacher = ModelConv(d,
                            ks,
                            d_output,
                            has_bn=args.teacher_bn,
                            bn_before_relu=args.bn_before_relu,
                            leaky_relu=args.leaky_relu).cuda()

    if args.load_teacher is not None:
        log.info("Loading teacher from: " + args.load_teacher)
        checkpoint = torch.load(args.load_teacher)
        teacher.load_state_dict(checkpoint['net'])

        if "inactive_nodes" in checkpoint:
            inactive_nodes = checkpoint["inactive_nodes"]
            masks = checkpoint["masks"]
            ratios = checkpoint["ratios"]
            inactive_nodes2, masks2 = prune(teacher, ratios)

            for m, m2 in zip(masks, masks2):
                if (m - m2).norm() > 1e-3:
                    print(m)
                    print(m2)
                    raise RuntimeError("New mask is not the same as old mask")

            for inactive, inactive2 in zip(inactive_nodes, inactive_nodes2):
                if set(inactive) != set(inactive2):
                    raise RuntimeError(
                        "New inactive set is not the same as old inactive set")

            # Make sure the last layer is normalized.
            # teacher.normalize_last()
            # teacher.final_w.weight.data /= 3
            # teacher.final_w.bias.data /= 3
            active_nodes = [[kk for kk in range(k) if kk not in a]
                            for a, k in zip(inactive_nodes, ks)]
            active_ks = [len(a) for a in active_nodes]
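            # active_nodes[l] lists the teacher nodes of layer l that survived
            # pruning; active_ks[l] is their count, and the student below is
            # sized to this pruned width.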
        else:
            active_nodes = None
            active_ks = ks

    else:
        log.info("Init teacher..")
        teacher.init_w(use_sep=not args.no_sep,
                       weight_choices=list(args.weight_choices))
        if args.teacher_strength_decay > 0:
            # Prioritize teacher node.
            teacher.prioritize(args.teacher_strength_decay)

        teacher.normalize()
        log.info("Teacher weights initiailzed randomly...")
        active_nodes = None
        active_ks = ks

    log.info(f"Active ks: {active_ks}")

    if args.load_student is None:
        if not args.use_cnn:
            student = Model(d[0],
                            active_ks,
                            d_output,
                            multi=args.node_multi,
                            has_bias=not args.no_bias,
                            has_bn=args.bn,
                            has_bn_affine=args.bn_affine,
                            bn_before_relu=args.bn_before_relu).cuda()
        else:
            student = ModelConv(d,
                                active_ks,
                                d_output,
                                multi=args.node_multi,
                                has_bn=args.bn,
                                bn_before_relu=args.bn_before_relu).cuda()

        # student can start with smaller norm.
        student.scale(args.student_scale_down)

    # Specify some teacher structure.
    '''
    teacher.w0.weight.data.zero_()
    span = d // ks[0]
    for i in range(ks[0]):
        teacher.w0.weight.data[i, span*i:span*i+span] = 1
    '''

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=4)
    eval_loader = torch.utils.data.DataLoader(eval_dataset,
                                              batch_size=args.eval_batchsize,
                                              shuffle=True,
                                              num_workers=4)

    if args.teacher_bias_tune:
        teacher_tune.tune_teacher(eval_loader, teacher)
    if args.teacher_bias_last_layer_tune:
        teacher_tune.tune_teacher_last_layer(eval_loader, teacher)

    # teacher.w0.bias.data.uniform_(-1, 0)
    # teacher.init_orth()

    # init_w(teacher.w0)
    # init_w(teacher.w1)
    # init_w(teacher.w2)

    # init_w2(teacher.w0, multiplier=args.init_multi)
    # init_w2(teacher.w1, multiplier=args.init_multi)
    # init_w2(teacher.w2, multiplier=args.init_multi)

    all_all_corrs = []

    log.info("=== Start ===")
    std = args.data_std

    stats_op = stats_operator.StatsCollector(teacher, student)

    # Compute Correlation between teacher and student activations.
    stats_op.add_stat(stats_operator.StatsCorr,
                      active_nodes=active_nodes,
                      cnt_thres=0.9)

    if args.cross_entropy:
        stats_op.add_stat(stats_operator.StatsCELoss)

        loss = nn.CrossEntropyLoss().cuda()

        def loss_func(predicted, target):
            _, target_y = target.max(1)
            return loss(predicted, target_y)
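        # target.max(1) returns (values, indices); keeping only the indices is
        # an argmax over classes, which is the integer-label format that
        # nn.CrossEntropyLoss expects.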

    else:
        stats_op.add_stat(stats_operator.StatsL2Loss)
        loss_func = nn.MSELoss().cuda()

    # Duplicate training and testing.
    eval_stats_op = deepcopy(stats_op)
    stats_op.label = "train"
    eval_stats_op.label = "eval"

    stats_op.add_stat(stats_operator.StatsGrad)
    stats_op.add_stat(stats_operator.StatsMemory)
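    # Gradient and memory statistics are tracked only on the training copy;
    # the eval copy (a deepcopy, so its accumulators stay separate) only gets
    # StatsHs added below when args.stats_H is set.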

    if args.stats_H:
        eval_stats_op.add_stat(stats_operator.StatsHs)

    # pickle.dump(model2numpy(teacher), open("weights_gt.pickle", "wb"), protocol=2)

    all_stats = []
    for i in range(args.num_trial):
        if args.load_student is None:
            log.info("=== Trial %d, std = %f ===" % (i, std))
            student.reset_parameters()
            # student = copy.deepcopy(student_clone)
            # student.set_teacher_sign(teacher, scale=1)
            if args.perturb is not None:
                student.set_teacher(teacher, args.perturb)
            if args.same_dir:
                student.set_teacher_dir(teacher)
            if args.same_sign:
                student.set_teacher_sign(teacher)

        else:
            log.info(f"Loading student {args.load_student}")
            student = torch.load(args.load_student)

        # init_corrs[-1] = predict_last_order(student, teacher, args)
        # alter_last_layer = predict_last_order(student, teacher, args)

        # import pdb
        # pdb.set_trace()

        stats = optimize(train_loader, eval_loader, teacher, student,
                         loss_func, stats_op, eval_stats_op, args, lrs)
        all_stats.append(stats)

    torch.save(all_stats, "stats.pickle")

    # log.info("Student network")
    # log.info(student.w1.weight)
    # log.info("Teacher network")
    # log.info(teacher.w1.weight)
    log.info(f"Working dir: {os.getcwd()}")
Example No. 3
    image_ds = test_ds.map(parse_image)
    image_ds = image_ds.prefetch(tf.data.experimental.AUTOTUNE)

    make_submission(model=model, image_ds=image_ds, filename="submission.csv")

    #####
    # BATCH PREDICTION
    #####
    import pickle
    from dataset import init_dataset, input_preprocess
    from config import TFRECORDS_VAL_PATH, TFRECORDS_TRAIN_PATH
    from sklearn.metrics import confusion_matrix, classification_report

    ds_val = init_dataset(
        os.path.join(TFRECORDS_VAL_PATH),
        is_target=True,
        shuffle=False,
        augment=False,
    )
    ds_val = ds_val.map(input_preprocess)
    ds_val = ds_val.batch(batch_size=1, drop_remainder=True)
    ds_val = ds_val.prefetch(tf.data.experimental.AUTOTUNE)

    labels_true, labels_pred, _ = predict_batch(model, ds_val)

    print(classification_report(y_true=labels_true, y_pred=labels_pred))
    print(confusion_matrix(y_true=labels_true, y_pred=labels_pred))

    with open(os.path.join(TEST_MODEL_FOLDER, "report.pickle"), "wb") as f:
        pickle.dump(
            classification_report(y_true=labels_true, y_pred=labels_pred), f)
    with open(os.path.join(TEST_MODEL_FOLDER, "matrix.pickle"), "wb") as f:
def main(args):
    cmd_line = " ".join(sys.argv)
    log.info(f"{cmd_line}")
    log.info(f"Working dir: {os.getcwd()}")
    set_all_seeds(args.seed)

    ks = args.ks
    lrs = eval(args.lr)
    if not isinstance(lrs, dict):
        lrs = {0: lrs}
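    # args.lr is eval()'d into either a single learning rate or a schedule
    # dict; a plain value is wrapped as {0: lr}. The keys presumably mark the
    # epoch at which each rate takes effect, e.g. "{0: 0.05, 30: 0.005}".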

    if args.perturb is not None or args.same_dir or args.same_sign:
        args.node_multi = 1

    if args.load_student is not None:
        args.num_trial = 1

    d, d_output, train_dataset, eval_dataset = init_dataset(args)

    if args.total_bp_iters > 0 and isinstance(train_dataset, RandomDataset):
        args.num_epoch = args.total_bp_iters / args.random_dataset_size
        if args.num_epoch != int(args.num_epoch):
            raise RuntimeError(
                f"random_dataset_size [{args.random_dataset_size}] cannot devide total_bp_iters [{args.total_bp_iters}]"
            )

        args.num_epoch = int(args.num_epoch)
        log.info(f"#Epoch is now set to {args.num_epoch}")

    # ks = [5, 6, 7, 8]
    # ks = [10, 15, 20, 25]
    # ks = [50, 75, 100, 125]

    # ks = [50, 75, 100, 125]
    log.info(args.pretty())
    log.info(f"ks: {ks}")

    if args.d_output > 0:
        d_output = args.d_output

    log.info(f"d_output: {d_output}")

    if not args.use_cnn:
        teacher = Model(d[0],
                        ks,
                        d_output,
                        has_bias=not args.no_bias,
                        has_bn=args.teacher_bn,
                        has_bn_affine=args.teacher_bn_affine,
                        bn_before_relu=args.bn_before_relu,
                        leaky_relu=args.leaky_relu).cuda()

    else:
        teacher = ModelConv(d,
                            ks,
                            d_output,
                            has_bn=args.teacher_bn,
                            bn_before_relu=args.bn_before_relu,
                            leaky_relu=args.leaky_relu).cuda()

    if args.load_teacher is not None:
        log.info("Loading teacher from: " + args.load_teacher)
        checkpoint = torch.load(args.load_teacher)
        teacher.load_state_dict(checkpoint['net'])

        if "inactive_nodes" in checkpoint:
            inactive_nodes = checkpoint["inactive_nodes"]
            masks = checkpoint["masks"]
            ratios = checkpoint["ratios"]
            inactive_nodes2, masks2 = prune(teacher, ratios)

            for m, m2 in zip(masks, masks2):
                if (m - m2).norm() > 1e-3:
                    print(m)
                    print(m2)
                    raise RuntimeError("New mask is not the same as old mask")

            for inactive, inactive2 in zip(inactive_nodes, inactive_nodes2):
                if set(inactive) != set(inactive2):
                    raise RuntimeError(
                        "New inactive set is not the same as old inactive set")

            # Make sure the last layer is normalized.
            # teacher.normalize_last()
            # teacher.final_w.weight.data /= 3
            # teacher.final_w.bias.data /= 3
            active_nodes = [[kk for kk in range(k) if kk not in a]
                            for a, k in zip(inactive_nodes, ks)]
            active_ks = [len(a) for a in active_nodes]
        else:
            active_nodes = None
            active_ks = ks

    else:
        log.info("Init teacher..`")
        teacher.init_w(use_sep=not args.no_sep)
        if args.teacher_strength_decay > 0:
            # Prioritize teacher node.
            teacher.prioritize(args.teacher_strength_decay)

        teacher.normalize()
        log.info("Teacher weights initiailzed randomly...")
        active_nodes = None
        active_ks = ks

    log.info(f"Active ks: {active_ks}")

    if args.load_student is None:
        if not args.use_cnn:
            student = Model(d[0],
                            active_ks,
                            d_output,
                            multi=args.node_multi,
                            has_bias=not args.no_bias,
                            has_bn=args.bn,
                            has_bn_affine=args.bn_affine,
                            bn_before_relu=args.bn_before_relu).cuda()
        else:
            student = ModelConv(d,
                                active_ks,
                                d_output,
                                multi=args.node_multi,
                                has_bn=args.bn,
                                bn_before_relu=args.bn_before_relu).cuda()

        # student can start with smaller norm.
        student.scale(args.student_scale_down)

    # Specify some teacher structure.
    '''
    teacher.w0.weight.data.zero_()
    span = d // ks[0]
    for i in range(ks[0]):
        teacher.w0.weight.data[i, span*i:span*i+span] = 1
    '''

    if args.cross_entropy:
        # Slower to converge since the information provided by the
        # loss function is not sufficient.
        loss = nn.CrossEntropyLoss().cuda()

        def loss_func(y, target):
            values, indices = target.max(1)
            err = loss(y, indices)
            return err
    else:
        loss = nn.MSELoss().cuda()

        def loss_func(y, target):
            return loss(y, target)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=4)
    eval_loader = torch.utils.data.DataLoader(eval_dataset,
                                              batch_size=args.eval_batchsize,
                                              shuffle=True,
                                              num_workers=4)

    if args.teacher_bias_tune:
        # Tune the teacher biases so that each node's activation/inactivation
        # ratio is approximately 0.5/0.5.
        for t in range(len(ks)):
            output = concatOutput(eval_loader, args.use_cnn, [teacher])
            estimated_bias = output[0]["post_lins"][t].median(dim=0)[0]
            teacher.ws_linear[t].bias.data[:] -= estimated_bias.cuda()

        # double check
        output = concatOutput(eval_loader, args.use_cnn, [teacher])
        for t in range(len(ks)):
            activate_ratio = (output[0]["post_lins"][t] > 0).float().mean(
                dim=0)
            print(f"{t}: {activate_ratio}")

    # teacher.w0.bias.data.uniform_(-1, 0)
    # teacher.init_orth()

    # init_w(teacher.w0)
    # init_w(teacher.w1)
    # init_w(teacher.w2)

    # init_w2(teacher.w0, multiplier=args.init_multi)
    # init_w2(teacher.w1, multiplier=args.init_multi)
    # init_w2(teacher.w2, multiplier=args.init_multi)

    all_all_corrs = []

    log.info("=== Start ===")
    std = args.data_std

    # pickle.dump(model2numpy(teacher), open("weights_gt.pickle", "wb"), protocol=2)

    all_stats = []
    for i in range(args.num_trial):
        if args.load_student is None:
            log.info("=== Trial %d, std = %f ===" % (i, std))
            student.reset_parameters()
            # student = copy.deepcopy(student_clone)
            # student.set_teacher_sign(teacher, scale=1)
            if args.perturb is not None:
                student.set_teacher(teacher, args.perturb)
            if args.same_dir:
                student.set_teacher_dir(teacher)
            if args.same_sign:
                student.set_teacher_sign(teacher)

        else:
            log.info(f"Loading student {args.load_student}")
            student = torch.load(args.load_student)

        # init_corrs[-1] = predict_last_order(student, teacher, args)
        # alter_last_layer = predict_last_order(student, teacher, args)

        # import pdb
        # pdb.set_trace()

        stats = optimize(train_loader, eval_loader, teacher, student,
                         loss_func, active_nodes, args, lrs)
        all_stats.append(stats)

    torch.save(all_stats, "stats.pickle")

    # log.info("Student network")
    # log.info(student.w1.weight)
    # log.info("Teacher network")
    # log.info(teacher.w1.weight)
    log.info(f"Working dir: {os.getcwd()}")
Example No. 5
def run(cfg):
    '''
    run the training loops
    '''

    # torch gpu setup
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg.gpu_idx)

    # make the exp dir
    os.makedirs(cfg.exp_dir, exist_ok=True)

    # set the seeds
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # set cudnn to reproducibility mode
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print("Loading dataset...")
    # setup datasets
    dset_train, dset_val = init_dataset(dataset=cfg.default_opts['dataset'],
                                        **cfg.DATASET, eval_only=cfg.eval_only)

    # init loaders
    trainloader = torch.utils.data.DataLoader(dset_train,
                                              num_workers=cfg.num_workers,
                                              pin_memory=True,
                                              batch_size=cfg.batch_size,
                                              shuffle=True)

    if dset_val is not None:
        valloader = torch.utils.data.DataLoader(dset_val,
                                                num_workers=cfg.num_workers,
                                                pin_memory=True,
                                                batch_size=cfg.batch_size,
                                                shuffle=False)
    else:
        valloader = None

    # test loaders
    eval_vars = None

    # init the model
    model, stats, optimizer_state = init_model(cfg, add_log_vars=eval_vars)
    start_epoch = stats.epoch + 1

    # move model to gpu
    if torch.cuda.is_available():
        model.cuda()

    optimizer, scheduler = init_optimizer(
        model, optimizer_state=optimizer_state, **cfg.SOLVER)

    print("Starting main loop...")
    # If evaluation just run it now and exit
    if cfg.eval_only:
        with stats:
            trainvalidate(cfg, model, stats, 0, valloader,
                          [namedtuple('dummyopt', 'num_iter')(num_iter=1)],
                          True, visdom_env_root=get_visdom_env(cfg),
                          exp_dir=cfg.exp_dir)
        return

    for epoch in range(start_epoch, cfg.SOLVER['max_epochs']):
        with stats:  # automatic new_epoch and plotting at every epoch start

            # train loop
            trainvalidate(cfg, model, stats, epoch, trainloader, optimizer,
                          False, visdom_env_root=get_visdom_env(cfg),
                          exp_dir=cfg.exp_dir)

            if valloader is not None:
                # val loop
                trainvalidate(cfg, model, stats, epoch, valloader,
                              [namedtuple('dummyopt', 'num_iter')(num_iter=1)],
                              True, visdom_env_root=get_visdom_env(cfg),
                              exp_dir=cfg.exp_dir)
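                # The validation pass reuses trainvalidate() with a throwaway
                # single-iteration "optimizer" namedtuple and a True flag,
                # presumably marking a validation-only pass in which no
                # parameters are updated.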

            assert stats.epoch == epoch, "inconsistent stats!"

            # delete previous models if required
            if cfg.store_checkpoints_purge > 0:
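                # purge everything older than the most recent
                # `store_checkpoints_purge` epochs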
                for prev_epoch in range(epoch-cfg.store_checkpoints_purge):
                    purge_epoch(cfg.exp_dir, prev_epoch)

            if cfg.store_checkpoints:
                outfile = get_checkpoint(cfg.exp_dir, epoch)
                save_model(model, stats, outfile, optimizer=optimizer)

            for sch in scheduler:
                sch.step()
Example No. 6
def main(args):
    basic_tools.start(args)

    ks = parse_ks(args.ks)
    lrs = parse_lr(args.lr)

    if args.perturb is not None or args.same_dir or args.same_sign:
        args.node_multi = 1

    if args.load_student is not None:
        args.num_trial = 1

    if args.load_dataset_path is not None:
        train_dataset = torch.load(
            os.path.join(args.load_dataset_path, "train_dataset.pth"))
        eval_dataset = torch.load(
            os.path.join(args.load_dataset_path, "eval_dataset.pth"))
        saved = torch.load(
            os.path.join(args.load_dataset_path, "params_dataset.pth"))
        d = saved["d"]
        d_output = saved["d_output"]
    else:
        d, d_output, train_dataset, eval_dataset = init_dataset(args)

    if args.save_dataset:
        print("Saving training dataset")
        torch.save(train_dataset, "train_dataset.pth")
        print("Saving eval dataset")
        torch.save(eval_dataset, "eval_dataset.pth")
        print("Saving dataset params")
        torch.save(dict(d=d, d_output=d_output), "params_dataset.pth")
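        # These three files are exactly what --load_dataset_path expects above
        # (train_dataset.pth, eval_dataset.pth, params_dataset.pth).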

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=4)
    eval_loader = torch.utils.data.DataLoader(eval_dataset,
                                              batch_size=args.eval_batchsize,
                                              shuffle=True,
                                              num_workers=4)

    if args.total_bp_iters > 0 and isinstance(train_dataset, RandomDataset):
        args.num_epoch = args.total_bp_iters / args.random_dataset_size
        if args.num_epoch != int(args.num_epoch):
            raise RuntimeError(
                f"random_dataset_size [{args.random_dataset_size}] cannot devide total_bp_iters [{args.total_bp_iters}]"
            )

        args.num_epoch = int(args.num_epoch)
        print(f"#Epoch is now set to {args.num_epoch}")

    print(f"ks: {ks}")
    print(f"lr: {lrs}")

    if args.d_output > 0:
        d_output = args.d_output

    print(f"d_output: {d_output}")
    loss_func = initialize_loss_func(args)

    if args.save_train_dataset:
        print("Save training dataset")
        torch.save(train_loader, "train_dataset.pth")

    if args.save_eval_dataset:
        print("Save eval dataset")
        torch.save(eval_loader, "eval_dataset.pth")

    if checkpoint.exist_checkpoint():
        cp = checkpoint.load_checkpoint()
    elif args.resume_from_checkpoint is not None:
        cp = checkpoint.load_checkpoint(filename=args.resume_from_checkpoint)
    else:
        teacher, student = initialize_networks(d, ks, d_output, args)

        if args.load_teacher is None:
            tune_teacher_model(teacher, train_loader, eval_loader, args)

        if args.eval_teacher_prune_ratio > 0:
            print(
                f"Prune teacher weight during evaluation. Ratio: {args.eval_teacher_prune_ratio}"
            )
            noise_teacher = deepcopy(teacher)
            noise_teacher.prune_weight_bias(args.eval_teacher_prune_ratio)
        else:
            noise_teacher = teacher
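        # noise_teacher is the network used at evaluation time (teacher_eval
        # below): a pruned deepcopy when eval_teacher_prune_ratio > 0,
        # otherwise the teacher itself.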

        active_nodes = None

        print("=== Start ===")
        train_stats_op = initialize_train_stats_ops(teacher, student,
                                                    active_nodes, args)
        train_stats_op.label = "train"

        eval_stats_op = initialize_eval_stats_ops(teacher, student,
                                                  active_nodes, args)
        eval_stats_op.label = "eval"

        eval_train_stats_op = initialize_eval_stats_ops(
            teacher, student, active_nodes, args)
        eval_train_stats_op.label = "eval_train"

        if noise_teacher != teacher:
            eval_no_noise_stats_op = initialize_eval_stats_ops(
                teacher, student, active_nodes, args)
            eval_no_noise_stats_op.label = "eval_no_noise"
        else:
            eval_no_noise_stats_op = None

        cp = Namespace(trial_idx=0, all_stats=[], lr=None, epoch=0, \
                student=student, teacher=teacher, teacher_eval=noise_teacher, \
                train_stats_op=train_stats_op, eval_stats_op=eval_stats_op, \
                eval_train_stats_op=eval_train_stats_op, \
                eval_no_noise_stats_op=eval_no_noise_stats_op)
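        # cp bundles all resumable state (networks, stats ops, trial/epoch
        # counters), so a run can restart from checkpoint.load_checkpoint()
        # above without rebuilding it.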

    # teacher.w0.bias.data.uniform_(-1, 0)
    # teacher.init_orth()

    # init_w(teacher.w0)
    # init_w(teacher.w1)
    # init_w(teacher.w2)

    # init_w2(teacher.w0, multiplier=args.init_multi)
    # init_w2(teacher.w1, multiplier=args.init_multi)
    # init_w2(teacher.w2, multiplier=args.init_multi)

    # pickle.dump(model2numpy(teacher), open("weights_gt.pickle", "wb"), protocol=2)

    while cp.trial_idx < args.num_trial:
        print(
            f"=== Trial {cp.trial_idx}, std = {args.data_std}, dataset = {args.dataset} ==="
        )

        # init_corrs[-1] = predict_last_order(student, teacher, args)
        # alter_last_layer = predict_last_order(student, teacher, args)

        # import pdb
        # pdb.set_trace()
        optimize(train_loader, eval_loader, cp, loss_func, args, lrs)
        cp.all_stats.append(cp.stats)
        cp.epoch = 0
        cp.trial_idx += 1

    torch.save(cp.all_stats, "stats.pickle")

    # print("Student network")
    # print(student.w1.weight)
    # print("Teacher network")
    # print(teacher.w1.weight)
    print(f"Working dir: {os.getcwd()}")