# NOTE: project-local helpers (AttrDict, refine_args_by_dataset, get_output_dir_name,
# init_nnabla, get_dataset, Model, ModelVarType, create_ema_op, load_checkpoint,
# save_checkpoint, KVReporter, is_learn_sigma, sum_grad_norm, write_yaml,
# save_tiled_image) are assumed to come from the surrounding repository.
import fnmatch
import os
import sys

import numpy as np

import nnabla as nn
import nnabla.solvers as S


def main(**kwargs):
    # set training args
    args = AttrDict(kwargs)
    refine_args_by_dataset(args)
    args.output_dir = get_output_dir_name(args.output_dir, args.dataset)

    comm = init_nnabla(ext_name="cudnn", device_id=args.device_id,
                       type_config="float", random_pseed=True)

    data_iterator = get_dataset(args, comm)

    model = Model(beta_strategy=args.beta_strategy,
                  num_diffusion_timesteps=args.num_diffusion_timesteps,
                  model_var_type=ModelVarType.get_vartype_from_key(
                      args.model_var_type),
                  attention_num_heads=args.num_attention_heads,
                  attention_resolutions=args.attention_resolutions,
                  scale_shift_norm=args.ssn,
                  base_channels=args.base_channels,
                  channel_mult=args.channel_mult,
                  num_res_blocks=args.num_res_blocks)

    # build graph
    x = nn.Variable(args.image_shape)  # assume data_iterator returns [0, 255]
    x_rescaled = x / 127.5 - 1  # rescale to [-1, 1]
    loss_dict, t = model.build_train_graph(
        x_rescaled,
        dropout=args.dropout,
        loss_scaling=None if args.loss_scaling == 1.0 else args.loss_scaling)
    assert loss_dict.batched_loss.shape == (args.batch_size, )
    assert t.shape == (args.batch_size, )
    assert t.persistent

    # optimizer
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    # for ema update
    # Note: this should be defined after solver.set_parameters() to avoid
    # update by solver.
    ema_op, ema_params = create_ema_op(nn.get_parameters(), 0.9999)
    dummy_solver_ema = S.Sgd()
    dummy_solver_ema.set_learning_rate(0)  # just in case
    dummy_solver_ema.set_parameters(ema_params)
    assert len(nn.get_parameters(grad_only=True)) == len(ema_params)
    assert len(nn.get_parameters(grad_only=False)) == 2 * len(ema_params)

    # for checkpoint
    solvers = {
        "main": solver,
        "ema": dummy_solver_ema,
    }

    start_iter = 0  # exclusive
    if args.resume:
        parent = os.path.dirname(os.path.abspath(args.output_dir))
        all_logs = sorted(
            fnmatch.filter(os.listdir(parent), "*{}*".format(args.dataset)))
        if len(all_logs):
            latest_dir = os.path.join(parent, all_logs[-1])
            checkpoints = sorted(
                fnmatch.filter(os.listdir(latest_dir), "checkpoint_*.json"))
            if len(checkpoints):
                latest_cp = os.path.join(latest_dir, checkpoints[-1])
                start_iter = load_checkpoint(latest_cp, solvers)

                for sname, slv in solvers.items():
                    slv.zero_grad()

    comm.barrier()

    # Reporter
    reporter = KVReporter(comm, save_path=args.output_dir,
                          skip_kv_to_monitor=False)
    # set all keys beforehand to prevent synchronization errors
    for i in range(4):
        reporter.set_key(f"loss_q{i}")
        if is_learn_sigma(model.model_var_type):
            reporter.set_key(f"vlb_q{i}")

    image_dir = os.path.join(args.output_dir, "image")
    if comm.rank == 0:
        os.makedirs(image_dir, exist_ok=True)

    if args.progress:
        from tqdm import trange
        piter = trange(start_iter + 1, args.n_iters + 1,
                       disable=comm.rank > 0, ncols=0)
    else:
        piter = range(start_iter + 1, args.n_iters + 1)

    # dump config
    if comm.rank == 0:
        args.dump()
        write_yaml(os.path.join(args.output_dir, "config.yaml"), args)
    comm.barrier()

    for i in piter:
        # update solver's lr
        # cur_lr = get_warmup_lr(lr, args.n_warmup, i)
        solver.set_learning_rate(args.lr)

        # evaluate graph
        dummy_solver_ema.zero_grad()  # just in case
        solver.zero_grad()
        for accum_iter in range(args.accum):  # gradient accumulation
            data, label = data_iterator.next()
            x.d = data.copy()
            loss_dict.loss.forward(clear_no_need_grad=True)

            # all-reduce gradients across processes only on the last
            # accumulation step
            all_reduce_cb = None
            if accum_iter == args.accum - 1:
                all_reduce_cb = comm.get_all_reduce_callback(
                    params=solver.get_parameters().values())

            loss_dict.loss.backward(clear_buffer=True,
                                    communicator_callbacks=all_reduce_cb)

        # logging
        # loss
        reporter.kv_mean("loss", loss_dict.loss)
        if is_learn_sigma(model.model_var_type):
reporter.kv_mean("vlb", loss_dict.vlb) # loss for each quantile for j in range(args.batch_size): ti = t.d[j] q_level = int(ti) * 4 // args.num_diffusion_timesteps assert q_level in ( 0, 1, 2, 3 ), f"q_level should be one of [0, 1, 2, 3], but {q_level} is given." reporter.kv_mean(f"loss_q{q_level}", float(loss_dict.batched_loss.d[j])) if is_learn_sigma(model.model_var_type): reporter.kv_mean(f"vlb_q{q_level}", loss_dict.vlb.d[j]) # update if args.grad_clip > 0: solver.clip_grad_by_norm(args.grad_clip) solver.update() # update ema params ema_op.forward(clear_no_need_grad=True) # grad norm if args.dump_grad_norm: gnorm = sum_grad_norm(solver.get_parameters().values()) reporter.kv_mean("grad", gnorm) # samples reporter.kv("samples", i * args.batch_size * comm.n_procs * args.accum) # iteration (only for no-progress) if not args.progress: reporter.kv("iteration", i) if i % args.show_interval == 0: if args.progress: desc = reporter.desc(reset=True, sync=True) piter.set_description(desc=desc) else: reporter.dump(file=sys.stdout if comm.rank == 0 else None, reset=True, sync=True) reporter.flush_monitor(i) if i > 0 and i % args.save_interval == 0: if comm.rank == 0: save_checkpoint(args.output_dir, i, solvers, n_keeps=3) comm.barrier() if i > 0 and i % args.gen_interval == 0: # sampling sample_out, _, _ = model.sample(shape=(16, ) + x.shape[1:], use_ema=True, progress=False) assert sample_out.shape == (16, ) + args.image_shape[1:] # scale back to [0, 255] sample_out = (sample_out + 1) * 127.5 save_path = os.path.join(image_dir, f"gen_{i}_{comm.rank}.png") save_tiled_image(sample_out.astype(np.uint8), save_path)
def save_checkpoint(self, path, epoch):
    # path: saved_models_dir
    from neu import checkpoint_util as cu
    os.makedirs(path, exist_ok=True)
    cu.save_checkpoint(path, epoch, self.solver)
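
# Round-trip sketch using neu.checkpoint_util directly; the paths and epoch
# number are illustrative. save_checkpoint writes the parameter/solver-state
# files plus a checkpoint_<n>.json index (the naming convention main() globs
# for above), and load_checkpoint restores both and returns the saved epoch.
import nnabla as nn
import nnabla.solvers as S
from neu import checkpoint_util as cu

solver = S.Adam()
solver.set_parameters(nn.get_parameters())

cu.save_checkpoint("./saved_models", 10, solver)
# ...after a restart, rebuild the same graph and solver, then resume:
start_epoch = cu.load_checkpoint("./saved_models/checkpoint_10.json", solver)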
# NOTE: `args` and the helpers below (load_data, data_source_cifar10,
# resnet23_prediction, resnet56_prediction, loss_function, categorical_error,
# read_yaml, create_learning_rate_scheduler, save_nnp, save_checkpoint,
# load_checkpoint) are assumed to come from the surrounding repository.
import functools
import math
import os

import numpy as np

import nnabla as nn
import nnabla.functions as F
import nnabla.solvers as S
from nnabla.ext_utils import get_extension_context
from nnabla.utils import save
from nnabla.utils.data_iterator import data_iterator


def train():
    bs_train, bs_valid = args.train_batch_size, args.val_batch_size
    extension_module = args.context
    ctx = get_extension_context(
        extension_module, device_id=args.device_id, type_config=args.type_config
    )
    nn.set_default_context(ctx)

    if args.input:
        train_loader, val_loader, n_train_samples, n_val_samples = load_data(
            bs_train, bs_valid
        )
    else:
        train_data_source = data_source_cifar10(
            train=True, shuffle=True, label_shuffle=True
        )
        val_data_source = data_source_cifar10(train=False, shuffle=False)
        n_train_samples = len(train_data_source.labels)
        n_val_samples = len(val_data_source.labels)
        # Data iterators
        train_loader = data_iterator(
            train_data_source, bs_train, None, False, False)
        val_loader = data_iterator(
            val_data_source, bs_valid, None, False, False)

        if args.shuffle_label:
            if not os.path.exists(args.output):
                os.makedirs(args.output)
            np.save(os.path.join(args.output, "x_train.npy"),
                    train_data_source.images)
            np.save(os.path.join(args.output, "y_shuffle_train.npy"),
                    train_data_source.labels)
            np.save(os.path.join(args.output, "y_train.npy"),
                    train_data_source.raw_label)
            np.save(os.path.join(args.output, "x_val.npy"),
                    val_data_source.images)
            np.save(os.path.join(args.output, "y_val.npy"),
                    val_data_source.labels)

    if args.model == "resnet23":
        model_prediction = resnet23_prediction
    elif args.model == "resnet56":
        model_prediction = resnet56_prediction
    prediction = functools.partial(
        model_prediction, ncls=10, nmaps=64, act=F.relu, seed=args.seed)

    # Create training graph
    test = False
    image_train = nn.Variable((bs_train, 3, 32, 32))
    label_train = nn.Variable((bs_train, 1))
    pred_train, _ = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid, _ = prediction(image_valid, test)
    loss_val = loss_function(pred_valid, label_valid)

    for param in nn.get_parameters().values():
        param.grad.zero()

    cfg = read_yaml("./learning_rate.yaml")
    print(cfg)
    lr_sched = create_learning_rate_scheduler(cfg.learning_rate_config)
    solver = S.Momentum(momentum=0.9, lr=lr_sched.get_lr())
    solver.set_parameters(nn.get_parameters())
    start_point = 0

    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Create monitors
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed

    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_err = MonitorSeries("Training error", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_vloss = MonitorSeries("Test loss", monitor, interval=1)

    # save_nnp
    contents = save_nnp({"x": image_valid}, {"y": pred_valid}, bs_valid)
    save.save(
        os.path.join(args.model_save_path,
                     (args.model + "_epoch0_result.nnp")), contents
    )

    train_iter = math.ceil(n_train_samples / bs_train)
    val_iter = math.ceil(n_val_samples / bs_valid)

    # Training loop
    for i in range(start_point, args.train_epochs):
        lr_sched.set_epoch(i)
        solver.set_learning_rate(lr_sched.get_lr())
        print("Learning Rate: ", lr_sched.get_lr())

        # Validation
        ve = 0.0
        vloss = 0.0
        print("## Validation")
        for j in range(val_iter):
            image, label = val_loader.next()
            image_valid.d = image
            label_valid.d = label
            loss_val.forward()
            vloss += loss_val.data.data.copy() * bs_valid
            ve += categorical_error(pred_valid.d, label)
        ve /= val_iter  # average over the number of validation batches
        vloss /= n_val_samples
        monitor_verr.add(i, ve)
        monitor_vloss.add(i, vloss)

        if int(i % args.model_save_interval) == 0:
            # save checkpoint file
            save_checkpoint(args.model_save_path, i, solver)

        # Forward/Zerograd/Backward
        print("## Training")
        e = 0.0
        loss = 0.0
        for k in range(train_iter):
            image, label = train_loader.next()
            image_train.d = image
            label_train.d = label
            loss_train.forward()
            solver.zero_grad()
            loss_train.backward()
            solver.update()
            e += categorical_error(pred_train.d, label_train.d)
            loss += loss_train.data.data.copy() * bs_train
        e /= train_iter  # average training error over the epoch
        loss /= n_train_samples
        monitor_loss.add(i, loss)
        monitor_err.add(i, e)
        monitor_time.add(i)

    nn.save_parameters(
        os.path.join(args.model_save_path,
                     "params_%06d.h5" % (args.train_epochs))
    )

    # save_nnp at the last epoch
    contents = save_nnp({"x": image_valid}, {"y": pred_valid}, bs_valid)
    save.save(os.path.join(args.model_save_path,
                           (args.model + "_result.nnp")), contents)
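
# For completeness, plausible definitions of the two helpers train() relies on.
# These are assumptions consistent with their call sites above (logits `pred`
# of shape (B, 10), integer labels of shape (B, 1), numpy arrays at the
# categorical_error call sites), not necessarily the repository's exact code.
import nnabla.functions as F


def loss_function(pred, label):
    # mean softmax cross-entropy over the batch
    return F.mean(F.softmax_cross_entropy(pred, label))


def categorical_error(pred, label):
    # fraction of misclassified samples; pred and label are numpy arrays here
    pred_label = pred.argmax(1)
    return (pred_label != label.flat).mean()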