Example #1
def main(**kwargs):
    # set training args
    args = AttrDict(kwargs)
    refine_args_by_dataset(args)

    args.output_dir = get_output_dir_name(args.output_dir, args.dataset)

    comm = init_nnabla(ext_name="cudnn",
                       device_id=args.device_id,
                       type_config="float",
                       random_pseed=True)

    data_iterator = get_dataset(args, comm)

    model = Model(beta_strategy=args.beta_strategy,
                  num_diffusion_timesteps=args.num_diffusion_timesteps,
                  model_var_type=ModelVarType.get_vartype_from_key(
                      args.model_var_type),
                  attention_num_heads=args.num_attention_heads,
                  attention_resolutions=args.attention_resolutions,
                  scale_shift_norm=args.ssn,
                  base_channels=args.base_channels,
                  channel_mult=args.channel_mult,
                  num_res_blocks=args.num_res_blocks)

    # build graph
    x = nn.Variable(args.image_shape)  # assume data_iterator returns [0, 255]
    x_rescaled = x / 127.5 - 1  # rescale to [-1, 1]
    loss_dict, t = model.build_train_graph(
        x_rescaled,
        dropout=args.dropout,
        loss_scaling=None if args.loss_scaling == 1.0 else args.loss_scaling)
    assert loss_dict.batched_loss.shape == (args.batch_size, )
    assert t.shape == (args.batch_size, )
    assert t.persistent

    # optimizer
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    # for ema update
    # Note: this must be created after solver.set_parameters() so that the EMA parameters are not updated by the solver.
    ema_op, ema_params = create_ema_op(nn.get_parameters(), 0.9999)
    dummy_solver_ema = S.Sgd()
    dummy_solver_ema.set_learning_rate(0)  # just in case
    dummy_solver_ema.set_parameters(ema_params)
    assert len(nn.get_parameters(grad_only=True)) == len(ema_params)
    assert len(nn.get_parameters(grad_only=False)) == 2 * len(ema_params)

    # for checkpoint
    solvers = {
        "main": solver,
        "ema": dummy_solver_ema,
    }

    start_iter = 0  # exclusive
    if args.resume:
        parent = os.path.dirname(os.path.abspath(args.output_dir))
        all_logs = sorted(
            fnmatch.filter(os.listdir(parent), "*{}*".format(args.dataset)))
        if len(all_logs):
            latest_dir = os.path.join(parent, all_logs[-1])
            checkpoints = sorted(
                fnmatch.filter(os.listdir(latest_dir), "checkpoint_*.json"))
            if len(checkpoints):
                latest_cp = os.path.join(latest_dir, checkpoints[-1])
                start_iter = load_checkpoint(latest_cp, solvers)

                for sname, slv in solvers.items():
                    slv.zero_grad()
    comm.barrier()

    # Reporter
    reporter = KVReporter(comm,
                          save_path=args.output_dir,
                          skip_kv_to_monitor=False)
    # register all keys in advance to prevent synchronization errors
    for i in range(4):
        reporter.set_key(f"loss_q{i}")
        if is_learn_sigma(model.model_var_type):
            reporter.set_key(f"vlb_q{i}")

    image_dir = os.path.join(args.output_dir, "image")
    if comm.rank == 0:
        os.makedirs(image_dir, exist_ok=True)

    if args.progress:
        from tqdm import trange
        piter = trange(start_iter + 1,
                       args.n_iters + 1,
                       disable=comm.rank > 0,
                       ncols=0)
    else:
        piter = range(start_iter + 1, args.n_iters + 1)

    # dump config
    if comm.rank == 0:
        args.dump()
        write_yaml(os.path.join(args.output_dir, "config.yaml"), args)

    comm.barrier()

    for i in piter:
        # update solver's lr
        # cur_lr = get_warmup_lr(lr, args.n_warmup, i)
        solver.set_learning_rate(args.lr)

        # evaluate graph
        dummy_solver_ema.zero_grad()  # just in case
        solver.zero_grad()
        for accum_iter in range(args.accum):  # gradient accumulation
            data, label = data_iterator.next()
            x.d = data.copy()

            loss_dict.loss.forward(clear_no_need_grad=True)

            all_reduce_cb = None
            if accum_iter == args.accum - 1:
                all_reduce_cb = comm.get_all_reduce_callback(
                    params=solver.get_parameters().values())

            loss_dict.loss.backward(clear_buffer=True,
                                    communicator_callbacks=all_reduce_cb)

            # logging
            # loss
            reporter.kv_mean("loss", loss_dict.loss)

            if is_learn_sigma(model.model_var_type):
                reporter.kv_mean("vlb", loss_dict.vlb)

            # loss for each quantile
            for j in range(args.batch_size):
                ti = t.d[j]
                q_level = int(ti) * 4 // args.num_diffusion_timesteps
                assert q_level in (
                    0, 1, 2, 3
                ), f"q_level should be one of [0, 1, 2, 3], but {q_level} is given."
                reporter.kv_mean(f"loss_q{q_level}",
                                 float(loss_dict.batched_loss.d[j]))

                if is_learn_sigma(model.model_var_type):
                    reporter.kv_mean(f"vlb_q{q_level}", loss_dict.vlb.d[j])

        # update
        if args.grad_clip > 0:
            solver.clip_grad_by_norm(args.grad_clip)
        solver.update()

        # update ema params
        ema_op.forward(clear_no_need_grad=True)

        # grad norm
        if args.dump_grad_norm:
            gnorm = sum_grad_norm(solver.get_parameters().values())
            reporter.kv_mean("grad", gnorm)

        # samples
        reporter.kv("samples", i * args.batch_size * comm.n_procs * args.accum)

        # iteration count (only when the progress bar is disabled)
        if not args.progress:
            reporter.kv("iteration", i)

        if i % args.show_interval == 0:
            if args.progress:
                desc = reporter.desc(reset=True, sync=True)
                piter.set_description(desc=desc)
            else:
                reporter.dump(file=sys.stdout if comm.rank == 0 else None,
                              reset=True,
                              sync=True)

            reporter.flush_monitor(i)

        if i > 0 and i % args.save_interval == 0:
            if comm.rank == 0:
                save_checkpoint(args.output_dir, i, solvers, n_keeps=3)

            comm.barrier()

        if i > 0 and i % args.gen_interval == 0:
            # sampling
            sample_out, _, _ = model.sample(shape=(16, ) + x.shape[1:],
                                            use_ema=True,
                                            progress=False)
            assert sample_out.shape == (16, ) + args.image_shape[1:]

            # scale back to [0, 255]
            sample_out = (sample_out + 1) * 127.5

            save_path = os.path.join(image_dir, f"gen_{i}_{comm.rank}.png")
            save_tiled_image(sample_out.astype(np.uint8), save_path)
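Example #1 relies on a `create_ema_op` helper that is not shown. Below is a minimal sketch of what such a helper could look like, assuming the standard exponential-moving-average rule `ema = decay * ema + (1 - decay) * param` built as an nnabla graph; the "ema" parameter scope and the use of `F.assign`/`F.sink` are assumptions, not the original implementation.

import nnabla as nn
import nnabla.functions as F


def create_ema_op(params, decay=0.9999):
    """Build a graph whose forward() applies ema = decay*ema + (1-decay)*param."""
    ema_params = {}
    update_ops = []
    with nn.parameter_scope("ema"):
        for name, p in params.items():
            # need_grad=False keeps the EMA copies out of nn.get_parameters(grad_only=True)
            ema = nn.parameter.get_parameter_or_create(
                name, shape=p.shape, need_grad=False)
            ema.d = p.d.copy()  # start the moving average from the current weights
            update_ops.append(F.assign(ema, decay * ema + (1.0 - decay) * p))
            ema_params[name] = ema
    # a single sink node so one forward() call executes every assignment
    return F.sink(*update_ops), ema_params

Creating the EMA copies after the main solver has already captured its parameters, and registering them only with the zero-learning-rate dummy solver, matches the asserts on parameter counts in the example above.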
Example #2
def save_checkpoint(self, path, epoch):
    # path: saved_models_dir
    from neu import checkpoint_util as cu
    os.makedirs(path, exist_ok=True)
    cu.save_checkpoint(path, epoch, self.solver)
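A checkpoint written this way can later be restored with the matching `load_checkpoint` helper from `neu.checkpoint_util`, as Examples #1 and #3 do. A minimal resume sketch; the concrete checkpoint file name is only illustrative, and the parameter graph and solver must be rebuilt the same way as at save time.

import nnabla as nn
import nnabla.solvers as S
from neu import checkpoint_util as cu

# Rebuild the same parameters and solver as at save time, then restore their state.
solver = S.Adam()
solver.set_parameters(nn.get_parameters())
last_epoch = cu.load_checkpoint("saved_models_dir/checkpoint_000010.json", solver)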
Example #3
def train():
    bs_train, bs_valid = args.train_batch_size, args.val_batch_size
    extension_module = args.context
    ctx = get_extension_context(
        extension_module, device_id=args.device_id, type_config=args.type_config
    )
    nn.set_default_context(ctx)

    if args.input:
        train_loader, val_loader, n_train_samples, n_val_samples = load_data(
            bs_train, bs_valid
        )

    else:
        train_data_source = data_source_cifar10(
            train=True, shuffle=True, label_shuffle=True
        )
        val_data_source = data_source_cifar10(train=False, shuffle=False)
        n_train_samples = len(train_data_source.labels)
        n_val_samples = len(val_data_source.labels)
        # Data Iterator
        train_loader = data_iterator(
            train_data_source, bs_train, None, False, False)
        val_loader = data_iterator(
            val_data_source, bs_valid, None, False, False)

        if args.shuffle_label:
            if not os.path.exists(args.output):
                os.makedirs(args.output)
            np.save(os.path.join(args.output, "x_train.npy"),
                    train_data_source.images)
            np.save(
                os.path.join(args.output, "y_shuffle_train.npy"),
                train_data_source.labels,
            )
            np.save(os.path.join(args.output, "y_train.npy"),
                    train_data_source.raw_label)
            np.save(os.path.join(args.output, "x_val.npy"),
                    val_data_source.images)
            np.save(os.path.join(args.output, "y_val.npy"),
                    val_data_source.labels)

    if args.model == "resnet23":
        model_prediction = resnet23_prediction
    elif args.model == "resnet56":
        model_prediction = resnet56_prediction
    prediction = functools.partial(
        model_prediction, ncls=10, nmaps=64, act=F.relu, seed=args.seed)

    # Create training graphs
    test = False
    image_train = nn.Variable((bs_train, 3, 32, 32))
    label_train = nn.Variable((bs_train, 1))
    pred_train, _ = prediction(image_train, test)

    loss_train = loss_function(pred_train, label_train)

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid, _ = prediction(image_valid, test)
    loss_val = loss_function(pred_valid, label_valid)

    for param in nn.get_parameters().values():
        param.grad.zero()

    cfg = read_yaml("./learning_rate.yaml")
    print(cfg)
    lr_sched = create_learning_rate_scheduler(cfg.learning_rate_config)
    solver = S.Momentum(momentum=0.9, lr=lr_sched.get_lr())
    solver.set_parameters(nn.get_parameters())
    start_point = 0

    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed

    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_err = MonitorSeries("Training error", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_vloss = MonitorSeries("Test loss", monitor, interval=1)

    # save_nnp
    contents = save_nnp({"x": image_valid}, {"y": pred_valid}, bs_valid)
    save.save(
        os.path.join(args.model_save_path,
                     (args.model+"_epoch0_result.nnp")), contents
    )

    train_iter = math.ceil(n_train_samples / bs_train)
    val_iter = math.ceil(n_val_samples / bs_valid)

    # Training-loop
    for i in range(start_point, args.train_epochs):
        lr_sched.set_epoch(i)
        solver.set_learning_rate(lr_sched.get_lr())
        print("Learning Rate: ", lr_sched.get_lr())
        # Validation
        ve = 0.0
        vloss = 0.0
        print("## Validation")
        for j in range(val_iter):
            image, label = val_loader.next()
            image_valid.d = image
            label_valid.d = label
            loss_val.forward()
            vloss += loss_val.data.data.copy() * bs_valid
            ve += categorical_error(pred_valid.d, label)
        ve /= val_iter  # average over the actual number of validation iterations
        vloss /= n_val_samples

        monitor_verr.add(i, ve)
        monitor_vloss.add(i, vloss)

        if i % args.model_save_interval == 0:
            # save checkpoint file
            save_checkpoint(args.model_save_path, i, solver)

        # Forward/Zerograd/Backward
        print("## Training")
        e = 0.0
        loss = 0.0
        for k in range(train_iter):

            image, label = train_loader.next()
            image_train.d = image
            label_train.d = label
            loss_train.forward()
            solver.zero_grad()
            loss_train.backward()
            solver.update()
            e += categorical_error(pred_train.d, label_train.d)
            loss += loss_train.data.data.copy() * bs_train
        e /= train_iter
        loss /= n_train_samples

        monitor_loss.add(i, loss)
        monitor_err.add(i, e)
        monitor_time.add(i)

    nn.save_parameters(
        os.path.join(args.model_save_path, "params_%06d.h5" %
                     (args.train_epochs))
    )

    # save_nnp_lastepoch
    contents = save_nnp({"x": image_valid}, {"y": pred_valid}, bs_valid)
    save.save(os.path.join(args.model_save_path,
              (args.model+"_result.nnp")), contents)
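Example #3 calls a `loss_function` helper that is not shown. A minimal sketch, under the assumption that it is the usual mean softmax cross-entropy over the class logits; the actual helper may differ.

import nnabla.functions as F


def loss_function(pred, label):
    # assumed implementation: softmax cross-entropy averaged over the batch
    return F.mean(F.softmax_cross_entropy(pred, label))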