def f_run_one_epoch(args, pt_model, loss_wrapper, \
                    device, monitor, \
                    data_loader, epoch_idx, optimizer = None, \
                    target_norm_method = None):
    """
    f_run_one_epoch:
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over the loss function
                     loss_wrapper.compute(generated, target)
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader
       epoch_idx:    int, index of the current epoch
       optimizer:    torch optimizer or None
                     if None, back-propagation will be skipped
                     (for the development set)
       target_norm_method: method to normalize target data
                     (by default, use pt_model.normalize_target)
    """
    # timer
    start_time = time.time()

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):
        #############
        # prepare
        #############
        # idx_orig is the original idx in the dataset
        #  which can be different from data_idx when shuffle = True
        #idx_orig = idx_orig.numpy()[0]
        #data_seq_info = data_info[0]

        # zero the gradient buffer before the forward pass
        if optimizer is not None:
            optimizer.zero_grad()

        ############
        # compute output
        ############
        # send data to device
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model(data_in)

        #####################
        # compute loss and back-propagate
        #####################
        # Two cases
        # 1. if the loss is defined as pt_model.loss, let the user do
        #    normalization inside pt_model.loss
        # 2. if loss_wrapper is defined as a class independent from the
        #    model, there is no way to normalize the data inside the
        #    loss_wrapper because the normalization weight is saved
        #    in pt_model
        if hasattr(pt_model, 'loss'):
            # case 1, pt_model.loss is available
            if isinstance(data_tar, torch.Tensor):
                data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            else:
                data_tar = []
            loss_computed = pt_model.loss(data_gen, data_tar)
        else:
            # case 2, loss is defined independently of pt_model
            if isinstance(data_tar, torch.Tensor):
                data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
                # there is no way to normalize the data inside the loss,
                # thus, do normalization here
                if target_norm_method is None:
                    normed_target = pt_model.normalize_target(data_tar)
                else:
                    normed_target = target_norm_method(data_tar)
            else:
                normed_target = []

            # return the loss from loss_wrapper
            # loss_computed may be [[loss_1, loss_2, ...],[flag_1, flag_2,.]]
            #   which contains multiple losses and flags indicating whether
            #   the corresponding loss should be taken into consideration
            #   for early stopping
            # or
            # loss_computed may simply be a tensor loss
            loss_computed = loss_wrapper.compute(data_gen, normed_target)

        # To handle cases where there are multiple loss functions:
        # when loss_computed is [[loss_1, loss_2, ...],[flag_1, flag_2,.]]
        #   loss: sum of [loss_1, loss_2, ...], for backward()
        #   loss_values: [loss_1.item(), loss_2.item(), ...], for logging
        #   loss_flags: [True/False, ...], for logging,
        #               whether loss_n is used for early stopping
        # when loss_computed is a single loss tensor
        #   loss: loss
        #   loss_values: [loss.item()]
        #   loss_flags: [True]
        loss, loss_values, loss_flags = nii_nn_tools.f_process_loss(
            loss_computed)

        # back-propagation using the summed loss
        if optimizer is not None:
            # backward propagation
            loss.backward()

            # apply gradient clipping
            if args.grad_clip_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    pt_model.parameters(), args.grad_clip_norm)

            # update parameters
            optimizer.step()

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is supposed to be the average loss value
            # over samples in the batch, thus, just loss_value
            # rather than loss_value / batchsize
            monitor.log_loss(loss_values, loss_flags, \
                             (end_time - start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx, \
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

        # Save an intermediate model every n mini-batches (optional).
        # Note that if we re-start training with this intermediate model,
        #  the data will start from the 1st sample, not the one where
        #  we stopped
        if args.save_model_every_n_minibatches > 0 \
           and (data_idx + 1) % args.save_model_every_n_minibatches == 0 \
           and optimizer is not None and data_idx > 0:
            cp_names = nii_nn_manage_conf.CheckPointKey()
            tmp_model_name = nii_nn_tools.f_save_epoch_name(
                args, epoch_idx, '_{:05d}'.format(data_idx + 1))
            # save
            tmp_dic = {
                cp_names.state_dict : pt_model.state_dict(),
                cp_names.optimizer : optimizer.state_dict()
            }
            torch.save(tmp_dic, tmp_model_name)

    # loop done
    return
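
# ----------------------------------------------------------------------
# Example: a loss_wrapper compatible with f_run_one_epoch above.
# This is a minimal illustrative sketch, NOT part of the toolkit: the
# class name LossExample and the MSE/L1 terms are assumptions. Only the
# compute(generated, target) signature and the
# [[loss_1, loss_2, ...], [flag_1, flag_2, ...]] return convention are
# fixed by the caller (see nii_nn_tools.f_process_loss).
# ----------------------------------------------------------------------
class LossExample:
    """ Hypothetical loss wrapper with two loss terms (sketch only) """
    def __init__(self, l1_weight=0.5):
        self.l1_weight = l1_weight

    def compute(self, generated, target):
        # first term: mean squared error
        loss_mse = torch.nn.functional.mse_loss(generated, target)
        # second term: weighted L1 distance
        loss_l1 = torch.nn.functional.l1_loss(generated, target) \
                  * self.l1_weight
        # flags: only loss_mse drives early stopping
        return [[loss_mse, loss_l1], [True, False]]
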
def f_run_one_epoch_GAN(args, pt_model_G, pt_model_D, loss_wrapper, \
                        device, monitor, \
                        data_loader, epoch_idx, \
                        optimizer_G = None, optimizer_D = None, \
                        target_norm_method = None):
    """
    f_run_one_epoch_GAN:
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model_G:   pytorch model (torch.nn.Module), generator
       pt_model_D:   pytorch model (torch.nn.Module), discriminator
       loss_wrapper: a wrapper over the loss function
                     loss_wrapper.compute(generated, target)
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader
       epoch_idx:    int, index of the current epoch
       optimizer_G:  torch optimizer or None, for the generator
       optimizer_D:  torch optimizer or None, for the discriminator
                     if None, back-propagation will be skipped
                     (for the development set)
       target_norm_method: method to normalize target data
                     (by default, use pt_model.normalize_target)
    """
    # timer
    start_time = time.time()

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        # zero the gradient buffers before the forward pass
        if optimizer_G is not None:
            optimizer_G.zero_grad()
        if optimizer_D is not None:
            optimizer_D.zero_grad()

        # prepare data
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside the loss,
            # thus, do normalization here
            if target_norm_method is None:
                normed_target = pt_model_G.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)
        else:
            nii_display.f_die("target data is required")

        # to device (we assume noise will be generated by the model itself)
        # here we only provide the external condition
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)

        ############################
        # Update Discriminator
        ############################
        # train with real
        pt_model_D.zero_grad()
        d_out_real = pt_model_D(data_tar)
        errD_real = loss_wrapper.compute_gan_D_real(d_out_real)
        if optimizer_D is not None:
            errD_real.backward()

        # this should be given by pt_model_D or the loss wrapper
        #d_out_real_mean = d_out_real.mean()

        # train with fake
        # generate sample
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model_G(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model_G(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model_G(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model_G(data_in)

        # data_gen.detach() is required
        # https://github.com/pytorch/examples/issues/116
        d_out_fake = pt_model_D(data_gen.detach())
        errD_fake = loss_wrapper.compute_gan_D_fake(d_out_fake)
        if optimizer_D is not None:
            errD_fake.backward()

        errD = errD_real + errD_fake

        if optimizer_D is not None:
            optimizer_D.step()

        ############################
        # Update Generator
        ############################
        pt_model_G.zero_grad()
        d_out_fake_for_G = pt_model_D(data_gen)
        errG_gan = loss_wrapper.compute_gan_G(d_out_fake_for_G)

        # if defined, calculate the auxiliary loss
        if hasattr(loss_wrapper, "compute_aux"):
            errG_aux = loss_wrapper.compute_aux(data_gen, data_tar)
        else:
            errG_aux = torch.zeros_like(errG_gan)

        # if defined, calculate the feature-matching loss
        if hasattr(loss_wrapper, "compute_feat_match"):
            errG_feat = loss_wrapper.compute_feat_match(
                d_out_real, d_out_fake_for_G)
        else:
            errG_feat = torch.zeros_like(errG_gan)

        # sum the loss for the generator
        errG = errG_gan + errG_aux + errG_feat
        if optimizer_G is not None:
            errG.backward()
            optimizer_G.step()

        # construct the loss for logging and early stopping
        # only use errG_aux for early stopping
        loss_computed = [
            [errG_aux, errD_real, errD_fake, errG_gan, errG_feat],
            [True, False, False, False, False]]

        # to handle cases where there are multiple loss functions
        _, loss_vals, loss_flags = nii_nn_tools.f_process_loss(loss_computed)

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is supposed to be the average loss value
            # over samples in the batch, thus, just loss_value
            # rather than loss_value / batchsize
            monitor.log_loss(loss_vals, loss_flags, \
                             (end_time - start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx, \
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    return
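
# ----------------------------------------------------------------------
# Example: a GAN loss_wrapper exposing the interface called by
# f_run_one_epoch_GAN above. Illustrative sketch only: the method names
# (compute_gan_D_real, compute_gan_D_fake, compute_gan_G, compute_aux)
# are dictated by the caller, but the BCE-with-logits internals here are
# one common choice, not necessarily what the toolkit ships. compute_aux
# and compute_feat_match are optional (detected via hasattr).
# ----------------------------------------------------------------------
class LossGANExample:
    """ Hypothetical GAN loss wrapper (assumes D outputs raw logits) """
    def compute_gan_D_real(self, d_out_real):
        # discriminator should push real data towards label 1
        return torch.nn.functional.binary_cross_entropy_with_logits(
            d_out_real, torch.ones_like(d_out_real))

    def compute_gan_D_fake(self, d_out_fake):
        # ... and generated data towards label 0
        return torch.nn.functional.binary_cross_entropy_with_logits(
            d_out_fake, torch.zeros_like(d_out_fake))

    def compute_gan_G(self, d_out_fake):
        # generator tries to make D output 1 on generated data
        return torch.nn.functional.binary_cross_entropy_with_logits(
            d_out_fake, torch.ones_like(d_out_fake))

    def compute_aux(self, generated, target):
        # optional auxiliary reconstruction loss
        return torch.nn.functional.l1_loss(generated, target)
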
def f_run_one_epoch(args, pt_model, loss_wrapper, \
                    device, monitor, \
                    data_loader, epoch_idx, optimizer = None, \
                    target_norm_method = None):
    """
    f_run_one_epoch:
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over the loss function
                     loss_wrapper.compute(generated, target)
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader
       epoch_idx:    int, index of the current epoch
       optimizer:    torch optimizer or None
                     if None, back-propagation will be skipped
                     (for the development set)
       target_norm_method: method to normalize target data
                     (by default, use pt_model.normalize_target)
    """
    # timer
    start_time = time.time()

    # loop over samples, with a tqdm progress bar
    pbar = tqdm(data_loader)
    epoch_num = monitor.get_max_epoch()
    for data_idx, (data_in, data_tar, data_info, idx_orig) in enumerate(pbar):

        pbar.set_description("Epoch: {}/{}".format(epoch_idx, epoch_num))

        # idx_orig is the original idx in the dataset
        #  which can be different from data_idx when shuffle = True
        #idx_orig = idx_orig.numpy()[0]
        #data_seq_info = data_info[0]

        # zero the gradient buffer before the forward pass
        if optimizer is not None:
            optimizer.zero_grad()

        # compute output
        # send data to device
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model(data_in)

        # compute loss and back-propagate
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside the loss,
            # thus, do normalization here
            if target_norm_method is None:
                normed_target = pt_model.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)
        else:
            normed_target = []

        # return the loss from loss_wrapper
        # loss_computed may be [[loss_1, loss_2, ...],[flag_1, flag_2,.]]
        #   which contains multiple losses and flags indicating whether
        #   the corresponding loss should be taken into consideration
        #   for early stopping
        # or
        # loss_computed may simply be a tensor loss
        loss_computed = loss_wrapper.compute(data_gen, normed_target)

        # To handle cases where there are multiple loss functions:
        # when loss_computed is [[loss_1, loss_2, ...],[flag_1, flag_2,.]]
        #   loss: sum of [loss_1, loss_2, ...], for backward()
        #   loss_vals: [loss_1.item(), loss_2.item(), ...], for logging
        #   loss_flags: [True/False, ...], for logging,
        #               whether loss_n is used for early stopping
        # when loss_computed is a single loss tensor
        #   loss: loss
        #   loss_vals: [loss.item()]
        #   loss_flags: [True]
        loss, loss_vals, loss_flags = nii_nn_tools.f_process_loss(
            loss_computed)

        # back-propagation using the summed loss
        if optimizer is not None:
            loss.backward()
            optimizer.step()

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is supposed to be the average loss value
            # over samples in the batch, thus, just loss_value
            # rather than loss_value / batchsize
            monitor.log_loss(loss_vals, loss_flags, \
                             (end_time - start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx, \
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    pbar.close()
    return
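
# ----------------------------------------------------------------------
# Example: how f_run_one_epoch might be driven from an outer training
# loop. Illustrative sketch only: num_epochs, trn_loader, val_loader,
# monitor_trn and monitor_val are hypothetical names, and the real
# driver in this codebase also handles checkpointing and early stopping.
# ----------------------------------------------------------------------
def example_training_driver(args, pt_model, loss_wrapper, device,
                            monitor_trn, monitor_val,
                            trn_loader, val_loader, num_epochs):
    # hypothetical optimizer choice; the real one comes from the config
    optimizer = torch.optim.Adam(pt_model.parameters(), lr=1e-4)
    for epoch_idx in range(num_epochs):
        # training pass: optimizer given, so back-propagation runs
        pt_model.train()
        f_run_one_epoch(args, pt_model, loss_wrapper, device,
                        monitor_trn, trn_loader, epoch_idx, optimizer)
        # validation pass: optimizer=None skips back-propagation
        pt_model.eval()
        with torch.no_grad():
            f_run_one_epoch(args, pt_model, loss_wrapper, device,
                            monitor_val, val_loader, epoch_idx, None)
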
def f_run_one_epoch_WGAN(args, pt_model_G, pt_model_D, loss_wrapper, \
                         device, monitor, \
                         data_loader, epoch_idx, \
                         optimizer_G = None, optimizer_D = None, \
                         target_norm_method = None):
    """
    f_run_one_epoch_WGAN:
       similar to f_run_one_epoch_GAN, but for WGAN
    """
    # timer
    start_time = time.time()

    # number of critic updates per generator update (default 5)
    # note: argparse stores --wgan-critic-num as args.wgan_critic_num
    if hasattr(args, "wgan_critic_num"):
        num_critic = args.wgan_critic_num
    else:
        num_critic = 5
    # clamp value for the discriminator weights
    if hasattr(args, "wgan_clamp"):
        wgan_clamp = args.wgan_clamp
    else:
        wgan_clamp = 0.01

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        # zero the gradient buffers before the forward pass
        if optimizer_G is not None:
            optimizer_G.zero_grad()
        if optimizer_D is not None:
            optimizer_D.zero_grad()

        # prepare data
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside the loss,
            # thus, do normalization here
            if target_norm_method is None:
                normed_target = pt_model_G.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)
        else:
            nii_display.f_die("target data is required")

        # to device (we assume noise will be generated by the model itself)
        # here we only provide the external condition
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)

        ############################
        # Update Discriminator
        ############################
        # train with real
        pt_model_D.zero_grad()
        d_out_real = pt_model_D(data_tar)
        errD_real = loss_wrapper.compute_gan_D_real(d_out_real)
        if optimizer_D is not None:
            errD_real.backward()
        d_out_real_mean = d_out_real.mean()

        # train with fake
        # generate sample
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model_G(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model_G(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model_G(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model_G(data_in)

        # data_gen.detach() is required
        # https://github.com/pytorch/examples/issues/116
        d_out_fake = pt_model_D(data_gen.detach())
        errD_fake = loss_wrapper.compute_gan_D_fake(d_out_fake)
        if optimizer_D is not None:
            errD_fake.backward()
        d_out_fake_mean = d_out_fake.mean()

        errD = errD_real + errD_fake

        if optimizer_D is not None:
            optimizer_D.step()

        # clip the weights of the discriminator
        for p in pt_model_D.parameters():
            p.data.clamp_(-wgan_clamp, wgan_clamp)

        ############################
        # Update Generator
        ############################
        pt_model_G.zero_grad()
        d_out_fake_for_G = pt_model_D(data_gen)
        errG_gan = loss_wrapper.compute_gan_G(d_out_fake_for_G)
        errG_aux = loss_wrapper.compute_aux(data_gen, data_tar)
        errG = errG_gan + errG_aux

        # only update the generator after num_critic iterations
        # on the discriminator
        if data_idx % num_critic == 0 and optimizer_G is not None:
            errG.backward()
            optimizer_G.step()

        # computed every iteration so it is always available for logging
        d_out_fake_for_G_mean = d_out_fake_for_G.mean()

        # construct the loss for logging and early stopping
        # only use errG_aux for early stopping
        loss_computed = [
            [errG_aux, errG_gan, errD_real, errD_fake,
             d_out_real_mean, d_out_fake_mean, d_out_fake_for_G_mean],
            [True, False, False, False, False, False, False]]

        # to handle cases where there are multiple loss functions
        loss, loss_vals, loss_flags = nii_nn_tools.f_process_loss(
            loss_computed)

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is supposed to be the average loss value
            # over samples in the batch, thus, just loss_value
            # rather than loss_value / batchsize
            monitor.log_loss(loss_vals, loss_flags, \
                             (end_time - start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx, \
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    return
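
# ----------------------------------------------------------------------
# Example: a Wasserstein loss_wrapper for f_run_one_epoch_WGAN above.
# Illustrative sketch only: with the weight clamping done by the loop,
# the standard WGAN critic maximizes E[D(real)] - E[D(fake)], so the
# minimized losses are the signed means below. The class name and the
# L1 auxiliary term are assumptions, not toolkit code. Note that the
# loop calls compute_aux unconditionally, so it must be defined here.
# ----------------------------------------------------------------------
class LossWGANExample:
    """ Hypothetical WGAN loss wrapper (sketch only) """
    def compute_gan_D_real(self, d_out_real):
        # critic loss on real data: -E[D(real)]
        return -d_out_real.mean()

    def compute_gan_D_fake(self, d_out_fake):
        # critic loss on generated data: +E[D(fake)]
        return d_out_fake.mean()

    def compute_gan_G(self, d_out_fake):
        # generator loss: -E[D(fake)]
        return -d_out_fake.mean()

    def compute_aux(self, generated, target):
        # auxiliary reconstruction loss, called on every mini-batch
        return torch.nn.functional.l1_loss(generated, target)
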