Example #1
# Imports assumed for this snippet.
import nnabla as nn
import nnabla.solvers as S
from nnabla.testing import assert_allclose


def test_solver_zeroing():
    xs = [nn.Variable([2, 3, 4], need_grad=True) for _ in range(3)]

    s = S.Sgd(1)
    s.set_parameters({str(i): x for i, x in enumerate(xs)})

    for x in xs:
        x.data.fill(1)
        x.grad.zero()

    s.weight_decay(1.0)
    s.update()
    for x in xs:
        # The gradient is still flagged as (lazily) zeroed, so both weight_decay and
        # update are skipped and the gradient buffer is never referenced.
        assert x.grad.zeroing
        assert_allclose(x.d, 1)

    for x in xs:
        x.grad.fill(1)

    s.weight_decay(0.1)
    s.update()

    for x in xs:
        assert_allclose(x.d, 1 - (1 + 0.1))
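
# For reference (not part of the original test): the final assertion restates the
# plain SGD step with weight decay, w <- w - lr * (g + decay * w). A minimal NumPy
# check of the same arithmetic, using the values from the test above:
import numpy as np

lr, decay = 1.0, 0.1
w = np.ones((2, 3, 4))            # parameter value going into the second update
g = np.ones_like(w)               # gradient filled with 1
w_new = w - lr * (g + decay * w)  # weight decay is folded into the gradient
assert np.allclose(w_new, 1 - (1 + 0.1))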
Example #2
def test_simple_loop():
    nn.clear_parameters()

    x = nn.Variable.from_numpy_array(np.random.randn(10, 3, 128, 128))
    t = nn.Variable.from_numpy_array(np.random.randint(0, 100, (10, )))

    unet = UNet(num_classes=1,
                model_channels=128,
                output_channels=3,
                num_res_blocks=2,
                attention_resolutions=(16, 8),
                attention_num_heads=4,
                channel_mult=(1, 1, 2, 2, 4, 4))
    y = unet(x, t)

    loss = F.mean(F.squared_error(y, x))

    import nnabla.solvers as S
    solver = S.Sgd()
    solver.set_parameters(nn.get_parameters())

    from tqdm import trange
    tr = trange(100)
    for i in tr:
        loss.forward(clear_no_need_grad=True)
        solver.zero_grad()
        loss.backward(clear_buffer=True)
        solver.update()

        tr.set_description(f"diff: {loss.d.copy():.5f}")
Example #3
def training(steps, learning_rate):
    solver = S.Sgd(learning_rate)
    solver.set_parameters(
        nn.get_parameters())  # Set parameter variables to be updated.
    for i in range(steps):
        x.d, t.d = data.next()
        loss.forward()
        solver.zero_grad()  # Initialize gradients of all parameters to zero.
        loss.backward()
        solver.weight_decay(1e-5)  # Apply weight decay as regularization
        solver.update()
        if i % 100 == 0:  # Print every 100 iterations
            print(i, loss.d)
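
# The training() helper above assumes that a graph (x, t, loss) and a data iterator
# named `data` were built beforehand. A hypothetical minimal setup it could be paired
# with (all names and shapes here are illustrative assumptions, not part of the source):
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

batch_size = 32
x = nn.Variable((batch_size, 1, 28, 28))  # input images
t = nn.Variable((batch_size, 1))          # integer labels
h = F.relu(PF.affine(x, 128, name="fc1"))
y = PF.affine(h, 10, name="fc2")
loss = F.mean(F.softmax_cross_entropy(y, t))
# `data` is assumed to be a data iterator whose next() returns an (images, labels) pair.
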
Example #4
def data_distill(model, uniform_data_iterator, num_iter):
    generated_img = []
    for _ in range(uniform_data_iterator.size //
                   uniform_data_iterator.batch_size):
        img, _ = uniform_data_iterator.next()
        dst_img = nn.Variable(img.shape, need_grad=True)
        dst_img.d = img
        img_params = OrderedDict()
        img_params['img'] = dst_img

        init_lr = 0.5
        solver = S.Adam(alpha=init_lr)
        solver.set_parameters(img_params)
        #scheduler = lr_scheduler.CosineScheduler(init_lr=0.5, max_iter=num_iter)
        scheduler = ReduceLROnPlateauScheduler(init_lr=init_lr,
                                               min_lr=1e-4,
                                               verbose=False,
                                               patience=100)
        dummy_solver = S.Sgd(lr=0)
        dummy_solver.set_parameters(nn.get_parameters())
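        # Zero-lr solver over the frozen network parameters: only its zero_grad() is used
        # below to clear gradients accumulated into them by loss.backward(); its update()
        # is never called, so the network weights stay fixed.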

        for it in tqdm(range(num_iter)):
            lr = scheduler.get_learning_rate()
            solver.set_learning_rate(lr)

            global outs
            outs = []
            global batch_stats
            batch_stats = []

            y = model(denormalize(dst_img),
                      force_global_pooling=True,
                      training=False)  # denormalize to U(0, 255)
            y.forward(function_post_hook=get_output)
            assert len(outs) == len(batch_stats)
            loss = zeroq_loss(batch_stats, outs, dst_img)
            loss.forward()
            solver.zero_grad()
            dummy_solver.zero_grad()
            loss.backward()
            solver.weight_decay(1e-6)
            solver.update()

            scheduler.update_lr(loss.d)

        generated_img.append(dst_img.d)

    return generated_img
Example #5
def main():
    extension_module = args.context
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    with h5py.File(os.path.join(data_dir, "info.h5"), "r") as hf:
        Y_label = hf["label"][:]
        X_feature = hf["feature"]["train"][:]
        Y_train = hf["output"]["train"][:]
        parameter_list = [hf["param"]["weight"][:], hf["param"]["bias"][:]]

    feature_shape = X_feature.shape

    # computation graph
    feature_valid = nn.Variable((feature_shape[0], feature_shape[1]))
    output_valid = nn.Variable((feature_shape[0], 10))
    with nn.parameter_scope("final_layer"):
        pred = PF.affine(feature_valid, 10)
    loss, phi, l2 = loss_function(pred, output_valid, len(X_feature))

    # initialize parameters from the values loaded above
    for weight, param in zip(nn.get_parameters().values(), parameter_list):
        weight.data.copy_from(nn.NdArray.from_numpy_array(param))

    solver = S.Sgd(lr=1.0)
    solver.set_parameters(nn.get_parameters())
    for ind, (name, param) in enumerate(nn.get_parameters().items()):
        param.grad.zero()
        param.need_grad = True

    calculate_alpha(
        parameter_list,
        X_feature,
        Y_train,
        Y_label,
        feature_valid,
        solver,
        output_valid,
        pred,
        loss,
        phi,
        l2,
    )
Example #6
def training_rnn(steps, learning_rate):
    solver = S.Sgd(learning_rate)
    solver.set_parameters(
        nn.get_parameters())  # Set parameter variables to be updated.
    for i in range(steps):
        minibatch = data.next()
        img, t.d = minibatch
        seq_img = split_grid4(img)
        h0.d = 0  # Initialize as 0
        for x, subimg in zip(seq_x, seq_img):
            x.d = subimg
        loss.forward()
        solver.zero_grad()  # Initialize gradients of all parameters to zero.
        loss.backward()
        solver.weight_decay(1e-5)  # Apply weight decay as regularization
        solver.update()
        if i % 100 == 0:  # Print every 100 iterations
            print(i, loss.d)
Example #7
xi -= 1
id(xi)
# -
# !The following does not update `xi` in place; it rebinds `xi` to a new NdArray object.
# !xi = xi + 1

# !The following copies the result of `xi + 1` to `xi`.
xi.copy_from(xi + 1)
assert np.all(xi.data == (np.arange(4).reshape(2, 2) + 1))

# In-place operations like `+=`, `*=` can also be used (and are more efficient).
xi += 1
assert np.all(xi.data == (np.arange(4).reshape(2, 2) + 2))

# ## Solver
solver = S.Sgd(lr=0.00001)
solver.set_parameters(nn.get_parameters())
# -
# !Set random data
x.d = np.random.randn(*x.shape)
label.d = np.random.randn(*label.shape)

# !Forward
loss.forward()

# -
solver.zero_grad()
loss.backward()
solver.update()
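
# A sketch (not in the original snippet): repeating the steps above as a training loop,
# assuming the same x, label, loss, and solver defined earlier and random data as before.
for i in range(100):
    x.d = np.random.randn(*x.shape)
    label.d = np.random.randn(*label.shape)
    loss.forward()
    solver.zero_grad()
    loss.backward()
    solver.update()
    if i % 10 == 0:
        print(i, loss.d)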

Example #8
    label = nn.Variable([g_batch_size, total_box_num, g_class_num + 4])
    ssd_train_loss = ssd_loss(ssd_conf_output,
                              ssd_loc_output,
                              label,
                              _alpha=g_ssd_loss_alpha)
    print('ssd_train_loss = {}'.format(ssd_train_loss))

    # [get dataset]
    my_tdata = load_dataset_npz(g_train_data_path)
    my_vdata = load_dataset_npz(g_test_data_path)
    tdata_num = len(my_tdata['image'])
    iter_num_max = int(np.ceil(tdata_num / g_batch_size))

    # [def solver]
    if g_optimizer == 'SGD':
        solver = S.Sgd(g_default_learning_rate)
    elif g_optimizer == 'Adam':
        solver = S.Adam(g_default_learning_rate)
    elif g_optimizer == 'AdaBound':
        solver = S.AdaBound(g_default_learning_rate)
    solver.set_parameters(coef_dict_source)

    # [def monitor]
    monitor = Monitor(g_save_log_dir)
    monitor_loss = MonitorSeries("Training loss",
                                 monitor,
                                 interval=g_monitor_interval)
    monitor_time = MonitorTimeElapsed("Training time",
                                      monitor,
                                      interval=g_monitor_interval)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
Example #9
# ### Backward propagation through the graph
print(nn.get_parameters())
# -
for param in nn.get_parameters().values():
    print(param)
    param.grad.zero()
# -
# !Compute backward
loss.backward()
# !Showing gradients.
for name, param in nn.get_parameters().items():
    print(name, param.shape, param.g.flat[:20])  # Showing first 20.
# ### Optimizing parameters (=Training)
# !Create a solver (gradient-based optimizer)
learning_rate = 1e-3
solver = S.Sgd(learning_rate)
solver.set_parameters(
    nn.get_parameters())  # Set parameter variables to be updated.

# !One step of training
x.d, t.d = data.next()
loss.forward()
solver.zero_grad()  # Initialize gradients of all parameters to zero.
loss.backward()
solver.weight_decay(1e-5)  # Apply weight decay as regularization
solver.update()
print(loss.d)

# -
for i in range(1000):
    x.d, t.d = data.next()
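    # (The snippet is truncated here; presumably the loop body repeats the single
    # training step shown above:)
    loss.forward()
    solver.zero_grad()
    loss.backward()
    solver.weight_decay(1e-5)
    solver.update()
    if i % 100 == 0:
        print(i, loss.d)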
Example #10
def main(**kwargs):
    # set training args
    args = AttrDict(kwargs)
    refine_args_by_dataset(args)

    args.output_dir = get_output_dir_name(args.output_dir, args.dataset)

    comm = init_nnabla(ext_name="cudnn",
                       device_id=args.device_id,
                       type_config="float",
                       random_pseed=True)

    data_iterator = get_dataset(args, comm)

    model = Model(beta_strategy=args.beta_strategy,
                  num_diffusion_timesteps=args.num_diffusion_timesteps,
                  model_var_type=ModelVarType.get_vartype_from_key(
                      args.model_var_type),
                  attention_num_heads=args.num_attention_heads,
                  attention_resolutions=args.attention_resolutions,
                  scale_shift_norm=args.ssn,
                  base_channels=args.base_channels,
                  channel_mult=args.channel_mult,
                  num_res_blocks=args.num_res_blocks)

    # build graph
    x = nn.Variable(args.image_shape)  # assume data_iterator returns [0, 255]
    x_rescaled = x / 127.5 - 1  # rescale to [-1, 1]
    loss_dict, t = model.build_train_graph(
        x_rescaled,
        dropout=args.dropout,
        loss_scaling=None if args.loss_scaling == 1.0 else args.loss_scaling)
    assert loss_dict.batched_loss.shape == (args.batch_size, )
    assert t.shape == (args.batch_size, )
    assert t.persistent

    # optimizer
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    # for ema update
    # Note: define this after solver.set_parameters() so the EMA parameters are not
    # registered to, and updated by, the main solver.
    ema_op, ema_params = create_ema_op(nn.get_parameters(), 0.9999)
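    # create_ema_op is project code; presumably ema_op.forward() (called once per
    # iteration after solver.update() below) refreshes each entry of ema_params as
    # ema <- 0.9999 * ema + (1 - 0.9999) * param.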
    dummy_solver_ema = S.Sgd()
    dummy_solver_ema.set_learning_rate(0)  # just in case
    dummy_solver_ema.set_parameters(ema_params)
    assert len(nn.get_parameters(grad_only=True)) == len(ema_params)
    assert len(nn.get_parameters(grad_only=False)) == 2 * len(ema_params)

    # for checkpoint
    solvers = {
        "main": solver,
        "ema": dummy_solver_ema,
    }

    start_iter = 0  # exclusive
    if args.resume:
        parent = os.path.dirname(os.path.abspath(args.output_dir))
        all_logs = sorted(
            fnmatch.filter(os.listdir(parent), "*{}*".format(args.dataset)))
        if len(all_logs):
            latest_dir = os.path.join(parent, all_logs[-1])
            checkpoints = sorted(
                fnmatch.filter(os.listdir(latest_dir), "checkpoint_*.json"))
            if len(checkpoints):
                latest_cp = os.path.join(latest_dir, checkpoints[-1])
                start_iter = load_checkpoint(latest_cp, solvers)

                for sname, slv in solvers.items():
                    slv.zero_grad()
    comm.barrier()

    # Reporter
    reporter = KVReporter(comm,
                          save_path=args.output_dir,
                          skip_kv_to_monitor=False)
    # set all keys in advance to prevent synchronization errors
    for i in range(4):
        reporter.set_key(f"loss_q{i}")
        if is_learn_sigma(model.model_var_type):
            reporter.set_key(f"vlb_q{i}")

    image_dir = os.path.join(args.output_dir, "image")
    if comm.rank == 0:
        os.makedirs(image_dir, exist_ok=True)

    if args.progress:
        from tqdm import trange
        piter = trange(start_iter + 1,
                       args.n_iters + 1,
                       disable=comm.rank > 0,
                       ncols=0)
    else:
        piter = range(start_iter + 1, args.n_iters + 1)

    # dump config
    if comm.rank == 0:
        args.dump()
        write_yaml(os.path.join(args.output_dir, "config.yaml"), args)

    comm.barrier()

    for i in piter:
        # update solver's lr
        # cur_lr = get_warmup_lr(lr, args.n_warmup, i)
        solver.set_learning_rate(args.lr)

        # evaluate graph
        dummy_solver_ema.zero_grad()  # just in case
        solver.zero_grad()
        for accum_iter in range(args.accum):  # accumulate gradients
            data, label = data_iterator.next()
            x.d = data.copy()

            loss_dict.loss.forward(clear_no_need_grad=True)

            all_reduce_cb = None
            if accum_iter == args.accum - 1:
                all_reduce_cb = comm.get_all_reduce_callback(
                    params=solver.get_parameters().values())

            loss_dict.loss.backward(clear_buffer=True,
                                    communicator_callbacks=all_reduce_cb)

            # logging
            # loss
            reporter.kv_mean("loss", loss_dict.loss)

            if is_learn_sigma(model.model_var_type):
                reporter.kv_mean("vlb", loss_dict.vlb)

            # loss for each quantile
            for j in range(args.batch_size):
                ti = t.d[j]
                q_level = int(ti) * 4 // args.num_diffusion_timesteps
                assert q_level in (
                    0, 1, 2, 3
                ), f"q_level should be one of [0, 1, 2, 3], but {q_level} is given."
                reporter.kv_mean(f"loss_q{q_level}",
                                 float(loss_dict.batched_loss.d[j]))

                if is_learn_sigma(model.model_var_type):
                    reporter.kv_mean(f"vlb_q{q_level}", loss_dict.vlb.d[j])

        # update
        if args.grad_clip > 0:
            solver.clip_grad_by_norm(args.grad_clip)
        solver.update()

        # update ema params
        ema_op.forward(clear_no_need_grad=True)

        # grad norm
        if args.dump_grad_norm:
            gnorm = sum_grad_norm(solver.get_parameters().values())
            reporter.kv_mean("grad", gnorm)

        # samples
        reporter.kv("samples", i * args.batch_size * comm.n_procs * args.accum)

        # iteration (only for no-progress)
        if not args.progress:
            reporter.kv("iteration", i)

        if i % args.show_interval == 0:
            if args.progress:
                desc = reporter.desc(reset=True, sync=True)
                piter.set_description(desc=desc)
            else:
                reporter.dump(file=sys.stdout if comm.rank == 0 else None,
                              reset=True,
                              sync=True)

            reporter.flush_monitor(i)

        if i > 0 and i % args.save_interval == 0:
            if comm.rank == 0:
                save_checkpoint(args.output_dir, i, solvers, n_keeps=3)

            comm.barrier()

        if i > 0 and i % args.gen_interval == 0:
            # sampling
            sample_out, _, _ = model.sample(shape=(16, ) + x.shape[1:],
                                            use_ema=True,
                                            progress=False)
            assert sample_out.shape == (16, ) + args.image_shape[1:]

            # scale back to [0, 255]
            sample_out = (sample_out + 1) * 127.5

            save_path = os.path.join(image_dir, f"gen_{i}_{comm.rank}.png")
            save_tiled_image(sample_out.astype(np.uint8), save_path)
Example #11
def infl_sgd(model_info_dict, file_dir_dict, use_all_params, need_evaluate):
    # params
    lr = model_info_dict['lr']
    seed = model_info_dict['seed']
    net_func = model_info_dict['net_func']
    batch_size = model_info_dict['batch_size']
    end_epoch = model_info_dict['end_epoch']
    target_epoch = model_info_dict['num_epochs']
    # files and dirs
    save_dir = file_dir_dict['save_dir']
    info_filename = file_dir_dict['info_filename']
    infl_filename = file_dir_dict['infl_filename']
    final_model_name = file_dir_dict['model_filename']
    final_model_path = os.path.join(save_dir, 'epoch%02d' % (target_epoch - 1),
                                    'weights', final_model_name)
    input_dir_name = os.path.dirname(file_dir_dict['train_csv'])

    # setup
    trainset, valset, image_shape, n_classes, ntr, nval = init_dataset(
        file_dir_dict['train_csv'], file_dir_dict['val_csv'], seed)
    n_channels, _h, _w = image_shape
    resize_size = get_image_size((_h, _w))
    idx_train = get_indices(ntr, seed)
    idx_val = get_indices(nval, seed)

    nn.load_parameters(final_model_path)
    trained_params = nn.get_parameters(grad_only=False)
    test = True

    grad_model = functools.partial(setup_model,
                                   net_func=net_func,
                                   n_classes=n_classes,
                                   n_channels=n_channels,
                                   resize_size=resize_size,
                                   test=test,
                                   reduction='sum')

    solver = S.Sgd(lr=lr)
    solver.set_parameters(trained_params)
    # gradient of the validation loss (initial u for the influence computation)
    u = compute_gradient(grad_model, solver, valset, batch_size, idx_val,
                         target_epoch, resize_size)

    test = False
    infl_model = functools.partial(setup_model,
                                   net_func=net_func,
                                   n_classes=n_classes,
                                   n_channels=n_channels,
                                   resize_size=resize_size,
                                   test=test)
    # influence
    infl_dict = {}
    info = np.load(os.path.join(save_dir, info_filename), allow_pickle=True)
    loss_fn = None
    for epoch in tqdm(range(target_epoch - 1, end_epoch - 1, -1),
                      desc='calc influence (3/3 steps)'):
        for step_info in info[epoch][::-1]:
            idx, seeds, lr, step = step_info['idx'], step_info[
                'seeds'], step_info['lr'], step_info['step']
            fn = select_modelfile_for_infl(use_all_params, final_model_path,
                                           save_dir, epoch, step)
            _, loss_fn, input_image = adjust_batch_size(
                infl_model, solver, 1, loss_fn)
            nn.load_parameters(fn)
            params = nn.get_parameters(grad_only=False)
            solver = S.Sgd(lr=lr)
            solver.set_parameters(params)
            X = []
            y = []
            for i, seed in zip(idx, seeds):
                i = int(i)
                image, label = get_data(trainset,
                                        idx_train[i],
                                        resize_size,
                                        test,
                                        seed=seed)
                X.append(image)
                y.append(label)
                input_image["image"].d = image
                input_image["label"].d = label
                loss_fn.forward()
                solver.zero_grad()
                loss_fn.backward(clear_buffer=True)

                csv_idx = idx_train[i]
                infl = infl_dict.get(csv_idx, [0.0])[-1]
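                # Accumulate this example's influence: lr * <u, per-example gradient> / batch size,
                # summed over all parameters (u starts from the validation gradient computed above
                # and is propagated backwards through the training steps in the outer loops).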
                for j, (key, param) in enumerate(
                        nn.get_parameters(grad_only=False).items()):
                    infl += lr * (u[key].d * param.g).sum() / idx.size

                # store infl
                file_name = trainset.get_filepath_to_data(csv_idx)
                file_name = os.path.join(input_dir_name, file_name)
                file_name = os.path.normpath(file_name)
                infl_dict[csv_idx] = [file_name, label, infl]

            # update u
            _, loss_fn, input_image = adjust_batch_size(
                infl_model, solver, len(idx), loss_fn)
            input_image["image"].d = X
            input_image["label"].d = np.array(y).reshape(-1, 1)
            loss_fn.forward()
            params = nn.get_parameters(grad_only=False)
            grad_params = {}
            for key, p in zip(params.keys(), nn.grad([loss_fn],
                                                     params.values())):
                grad_params[key] = p
            ug = 0
            # compute H[t]u[t]
            for key, uu in u.items():
                try:
                    ug += F.sum(uu * grad_params[key])
                except TypeError:
                    # gradients cannot be computed w.r.t. batch normalization running mean and variance
                    pass
            ug.forward()
            solver.zero_grad()
            ug.backward(clear_buffer=True)

            for j, (key, param) in enumerate(
                    nn.get_parameters(grad_only=False).items()):
                u[key].d -= lr * param.g / idx.size

        # sort by influence score
        infl_list = [val + [key] for key, val in infl_dict.items()]
        infl_list = sorted(infl_list, key=lambda x: (x[-2]))

        # save
        header = ['x:image', 'y:label', 'influence', 'datasource_index']
        data_type = 'object,int,float,int'
        if need_evaluate:
            save_infl_for_analysis(infl_list, use_all_params, save_dir,
                                   infl_filename, epoch, header, data_type)
    save_to_csv(filename=infl_filename,
                header=header,
                list_to_save=infl_list,
                data_type=data_type)
Example #12
def retrain(model_info_dict, file_dir_dict, retrain_all, escape_list=[]):
    # params
    lr = model_info_dict['lr']
    seed = model_info_dict['seed']
    net_func = model_info_dict['net_func']
    batch_size = model_info_dict['batch_size']
    end_epoch = model_info_dict['num_epochs']
    start_epoch = model_info_dict['start_epoch']
    # files and dirs
    save_dir = file_dir_dict['save_dir']
    info_filename = file_dir_dict['info_filename']
    score_filename = file_dir_dict['score_filename']
    info_file_path = os.path.join(save_dir, info_filename)
    # setup
    trainset, valset, image_shape, n_classes, ntr, nval = init_dataset(
        file_dir_dict['train_csv'], file_dir_dict['val_csv'], seed)
    testset = setup_nnabla_dataset(file_dir_dict['test_csv'])
    ntest = testset.size
    n_channels, _h, _w = image_shape
    resize_size = get_image_size((_h, _w))

    # Create training graphs
    test = False
    train_model = functools.partial(setup_model,
                                    net_func=net_func,
                                    n_classes=n_classes,
                                    n_channels=n_channels,
                                    resize_size=resize_size,
                                    test=test)
    # Create validation graphs
    test = True
    val_model = functools.partial(setup_model,
                                  net_func=net_func,
                                  n_classes=n_classes,
                                  n_channels=n_channels,
                                  resize_size=resize_size,
                                  test=test)

    # setup optimizer (SGD)
    fn = select_model_file(retrain_all, save_dir, start_epoch)
    nn.load_parameters(fn)
    solver = S.Sgd(lr=lr)
    solver.set_parameters(nn.get_parameters(grad_only=False))
    # get shuffled index using designated seed
    idx_train = get_indices(ntr, seed)
    idx_val = get_indices(nval, seed)
    idx_test = get_indices(ntest, seed)

    # training
    info = np.load(info_file_path, allow_pickle=True)
    score = []
    loss_train = None
    for epoch in tqdm(range(start_epoch, end_epoch), desc='retrain'):
        for step_info in info[epoch]:
            idx, seeds, lr = step_info['idx'], step_info['seeds'], step_info[
                'lr']
            X, y = get_batch_data(trainset,
                                  idx_train,
                                  idx,
                                  resize_size,
                                  test=False,
                                  seeds=seeds,
                                  escape_list=escape_list)
            if len(X) == 0:
                continue
            _, loss_train, input_image_train = adjust_batch_size(
                train_model, solver, len(X), loss_train)
            input_image_train["image"].d = X
            input_image_train["label"].d = y

            loss_train.forward()
            solver.zero_grad()
            loss_train.backward(clear_buffer=True)
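            # Presumably rescales the gradient so that samples removed via escape_list
            # contribute zero relative to the original batch size idx.size.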
            for key, param in nn.get_parameters(grad_only=False).items():
                param.g *= len(X) / idx.size
            solver.update()
        # evaluation
        loss_val, acc_val = eval_model(val_model, solver, valset, idx_val,
                                       batch_size, resize_size)
        loss_test, acc_test = eval_model(val_model, solver, testset, idx_test,
                                         batch_size, resize_size)
        score.append((loss_val, loss_test, acc_val, acc_test))
        # save
    save_to_csv(filename=score_filename,
                header=[
                    'val_loss',
                    'test_loss',
                    'val_accuracy',
                    'test_accuracy',
                ],
                list_to_save=score,
                data_type='float,float,float,float')
Example #13
def train(model_info_dict,
          file_dir_dict,
          use_all_params,
          need_evaluate,
          bundle_size=200):
    # params
    lr = model_info_dict['lr']
    seed = model_info_dict['seed']
    net_func = model_info_dict['net_func']
    batch_size = model_info_dict['batch_size']
    num_epochs = model_info_dict['num_epochs']
    infl_end_epoch = model_info_dict['end_epoch']
    # files and dirs
    save_dir = file_dir_dict['save_dir']
    info_filename = file_dir_dict['info_filename']
    model_filename = file_dir_dict['model_filename']
    score_filename = file_dir_dict['score_filename']
    # setup
    trainset, valset, image_shape, n_classes, ntr, nval = init_dataset(
        file_dir_dict['train_csv'], file_dir_dict['val_csv'], seed)
    n_channels, _h, _w = image_shape
    resize_size = get_image_size((_h, _w))
    # Create training graphs
    test = False
    train_model = functools.partial(setup_model,
                                    net_func=net_func,
                                    n_classes=n_classes,
                                    n_channels=n_channels,
                                    resize_size=resize_size,
                                    test=test)
    # Create validation graphs
    test = True
    val_model = functools.partial(setup_model,
                                  net_func=net_func,
                                  n_classes=n_classes,
                                  n_channels=n_channels,
                                  resize_size=resize_size,
                                  test=test)
    # setup optimizer (SGD)
    solver = S.Sgd(lr=lr)
    solver.set_parameters(nn.get_parameters(grad_only=False))

    # get shuffled index using designated seed
    idx_train = get_indices(ntr, seed)
    idx_val = get_indices(nval, seed)

    # training
    seed_train = 0
    info = []
    score = []
    loss_train = None
    for epoch in tqdm(range(num_epochs), desc='training (1/3 steps)'):
        idx = get_batch_indices(ntr, batch_size, seed=epoch)
        epoch_info = []
        c = 0
        k = 0
        params_dict = {}
        for j, i in enumerate(idx):
            seeds = list(range(seed_train, seed_train + i.size))
            seed_train += i.size
            epoch_info.append({
                'epoch': epoch,
                'step': j,
                'idx': i,
                'lr': lr,
                'seeds': seeds
            })
            if use_all_params and (epoch >= infl_end_epoch):
                params_dict, c, k = save_all_params(params_dict,
                                                    c, k, j, bundle_size,
                                                    len(idx), save_dir, epoch)
            X, y = get_batch_data(trainset,
                                  idx_train,
                                  i,
                                  resize_size,
                                  test=False,
                                  seeds=seeds)
            _, loss_train, input_image_train = adjust_batch_size(
                train_model, solver, len(X), loss_train)
            input_image_train["image"].d = X
            input_image_train["label"].d = y

            loss_train.forward()
            solver.zero_grad()
            loss_train.backward(clear_buffer=True)
            solver.update()
        info.append(epoch_info)
        # save if params are necessary for calculating influence
        if epoch >= infl_end_epoch - 1:
            dn = os.path.join(save_dir, 'epoch%02d' % (epoch), 'weights')
            ensure_dir(dn)
            nn.save_parameters(os.path.join(dn, model_filename),
                               params=nn.get_parameters(grad_only=False),
                               extension=".h5")
        # evaluation
        if need_evaluate:
            loss_tr, acc_tr = eval_model(val_model, solver, trainset,
                                         idx_train, batch_size, resize_size)
            loss_val, acc_val = eval_model(val_model, solver, valset, idx_val,
                                           batch_size, resize_size)
            score.append((loss_tr, loss_val, acc_tr, acc_val))
    # save epoch and step info
    np.save(os.path.join(save_dir, info_filename), arr=info)
    # save score
    if need_evaluate:
        save_to_csv(filename=score_filename,
                    header=[
                        'train_loss', 'val_loss', 'train_accuracy',
                        'val_accuracy'
                    ],
                    list_to_save=score,
                    data_type='float,float,float,float')