Example 1
def f_run_one_epoch(args,
                    pt_model, loss_wrapper, \
                    device, monitor,  \
                    data_loader, epoch_idx, optimizer = None, \
                    target_norm_method = None):
    """
    f_run_one_epoch: 
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss function
                     loss_wrapper.compute(generated, target) 
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader. 
       epoch_idx:    int, index of the current epoch
       optimizer:    torch optimizer or None
                     if None, back-propagation will be skipped
                     (for the development set)
       target_norm_method: method to normalize target data
                           (by default, use pt_model.normalize_target)
    """
    # timer
    start_time = time.time()

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        #############
        # prepare
        #############
        # idx_orig is the original idx in the dataset
        # which can be different from data_idx when shuffle = True
        #idx_orig = idx_orig.numpy()[0]
        #data_seq_info = data_info[0]

        # reset gradients before the forward pass
        if optimizer is not None:
            optimizer.zero_grad()

        ############
        # compute output
        ############
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model(data_in)

        #####################
        # compute loss and back-propagate
        #####################

        # Two cases
        # 1. if loss is defined as pt_model.loss, then let the user do
        #    normalization inside pt_model.loss
        # 2. if loss_wrapper is defined as a class independent from model
        #    there is no way to normalize the data inside the loss_wrapper
        #    because the normalization weight is saved in pt_model

        if hasattr(pt_model, 'loss'):
            # case 1, pt_model.loss is available
            if isinstance(data_tar, torch.Tensor):
                data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            else:
                data_tar = []

            loss_computed = pt_model.loss(data_gen, data_tar)
        else:
            # case 2, loss is defined independent of pt_model
            if isinstance(data_tar, torch.Tensor):
                data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
                # there is no way to normalize the data inside loss
                # thus, do normalization here
                if target_norm_method is None:
                    normed_target = pt_model.normalize_target(data_tar)
                else:
                    normed_target = target_norm_method(data_tar)
            else:
                normed_target = []

            # return the loss from loss_wrapper
            # loss_computed may be [[loss_1, loss_2, ...],[flag_1, flag_2,.]]
            #   which contains multiple losses and flags indicating whether
            #   the corresponding loss should be taken into consideration
            #   for early stopping
            # or
            # loss_computed may be simply a tensor loss
            loss_computed = loss_wrapper.compute(data_gen, normed_target)

        loss_values = [0]
        # To handle cases where there are multiple loss functions
        # when loss_computed is [[loss_1, loss_2, ...],[flag_1, flag_2, ...]]
        #   loss: sum of [loss_1, loss_2, ...], for backward()
        #   loss_values: [loss_1.item(), loss_2.item() ..], for logging
        #   loss_flags: [True/False, ...], for logging,
        #               whether loss_n is used for early stopping
        # when loss_computed is loss
        #   loss: loss
        #   loss_values: [loss.item()]
        #   loss_flags: [True]
        loss, loss_values, loss_flags = nii_nn_tools.f_process_loss(
            loss_computed)

        # Back-propagation using the summed loss
        if optimizer is not None:
            # backward propagation
            loss.backward()

            # apply gradient clip
            if args.grad_clip_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    pt_model.parameters(), args.grad_clip_norm)

            # update parameters
            optimizer.step()

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # each entry in loss_values is already the average loss over
            # samples in the batch, so log loss_values directly rather
            # than loss_values / batchsize
            monitor.log_loss(loss_values, loss_flags, \
                             (end_time-start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx*batchsize + idx,\
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

        # Save an intermediate model every n mini-batches (optional).
        # Note that if we restart training from this intermediate model,
        #  the data will start from the 1st sample, not the one where we stopped
        if args.save_model_every_n_minibatches > 0 \
           and (data_idx+1) % args.save_model_every_n_minibatches == 0 \
           and optimizer is not None and data_idx > 0:
            cp_names = nii_nn_manage_conf.CheckPointKey()
            tmp_model_name = nii_nn_tools.f_save_epoch_name(
                args, epoch_idx, '_{:05d}'.format(data_idx + 1))
            # save
            tmp_dic = {
                cp_names.state_dict: pt_model.state_dict(),
                cp_names.optimizer: optimizer.state_dict()
            }
            torch.save(tmp_dic, tmp_model_name)

    # loop done
    return
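
The comment block above describes the contract of nii_nn_tools.f_process_loss: loss_computed is either a single tensor, or a pair of lists [[loss_1, ...], [flag_1, ...]]. A minimal sketch of a function with that contract (illustration only, not the project's actual implementation) could be:

import torch

def f_process_loss_sketch(loss_computed):
    """Illustrative stand-in for nii_nn_tools.f_process_loss.

    Returns (loss for backward(), loss values for logging, loss flags).
    """
    if isinstance(loss_computed, torch.Tensor):
        # single-loss case
        return loss_computed, [loss_computed.item()], [True]
    # multi-loss case: sum for backward(), .item() values for logging
    losses, flags = loss_computed
    return sum(losses), [l.item() for l in losses], list(flags)
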
def f_run_one_epoch_GAN(
        args, pt_model_G, pt_model_D,
        loss_wrapper, \
        device, monitor,  \
        data_loader, epoch_idx,
        optimizer_G = None, optimizer_D = None, \
        target_norm_method = None):
    """
    f_run_one_epoch_GAN: 
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model_G:   pytorch model (torch.nn.Module) generator
       pt_model_D:   pytorch model (torch.nn.Module) discriminator
       loss_wrapper: a wrapper over loss function
                     loss_wrapper.compute(generated, target) 
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader. 
       epoch_idx:    int, index of the current epoch
       optimizer_G:  torch optimizer or None, for generator
       optimizer_D:  torch optimizer or None, for discriminator
                     if None, back-propagation will be skipped
                     (for the development set)
       target_norm_method: method to normalize target data
                            (by default, use pt_model_G.normalize_target)
    """
    # timer
    start_time = time.time()

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        # reset gradients before the forward pass
        if optimizer_G is not None:
            optimizer_G.zero_grad()
        if optimizer_D is not None:
            optimizer_D.zero_grad()

        # prepare data
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside loss
            # thus, do normalization here
            if target_norm_method is None:
                normed_target = pt_model_G.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)
        else:
            nii_display.f_die("target data is required")

        # to device (we assume noise will be generated by the model itself)
        # here we only provide external condition
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)

        ############################
        # Update Discriminator
        ############################
        # train with real
        pt_model_D.zero_grad()
        d_out_real = pt_model_D(data_tar)
        errD_real = loss_wrapper.compute_gan_D_real(d_out_real)
        if optimizer_D is not None:
            errD_real.backward()

        # this should be given by pt_model_D or loss wrapper
        #d_out_real_mean = d_out_real.mean()

        # train with fake
        #  generate sample
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model_G(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model_G(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model_G(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model_G(data_in)

        # data_gen.detach() is required
        # https://github.com/pytorch/examples/issues/116
        d_out_fake = pt_model_D(data_gen.detach())
        errD_fake = loss_wrapper.compute_gan_D_fake(d_out_fake)
        if optimizer_D is not None:
            errD_fake.backward()

        errD = errD_real + errD_fake
        if optimizer_D is not None:
            optimizer_D.step()

        ############################
        # Update Generator
        ############################
        pt_model_G.zero_grad()
        d_out_fake_for_G = pt_model_D(data_gen)
        errG_gan = loss_wrapper.compute_gan_G(d_out_fake_for_G)

        # if defined, calculate the auxiliary loss
        if hasattr(loss_wrapper, "compute_aux"):
            errG_aux = loss_wrapper.compute_aux(data_gen, data_tar)
        else:
            errG_aux = torch.zeros_like(errG_gan)

        # if defined, calculate the feature-matching loss
        if hasattr(loss_wrapper, "compute_feat_match"):
            errG_feat = loss_wrapper.compute_feat_match(
                d_out_real, d_out_fake_for_G)
        else:
            errG_feat = torch.zeros_like(errG_gan)

        # sum loss for generator
        errG = errG_gan + errG_aux + errG_feat

        if optimizer_G is not None:
            errG.backward()
            optimizer_G.step()

        # construct the loss for logging and early stopping
        # only use errG_aux for early-stopping
        loss_computed = [[errG_aux, errD_real, errD_fake, errG_gan, errG_feat],
                         [True, False, False, False, False]]

        # to handle cases where there are multiple loss functions
        _, loss_vals, loss_flags = nii_nn_tools.f_process_loss(loss_computed)

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # each entry in loss_vals is already the average loss over
            # samples in the batch, so log loss_vals directly rather
            # than loss_vals / batchsize
            monitor.log_loss(loss_vals, loss_flags, \
                             (end_time-start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx*batchsize + idx,\
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    return
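
f_run_one_epoch_GAN only assumes that loss_wrapper exposes compute_gan_D_real, compute_gan_D_fake and compute_gan_G, plus the optional compute_aux and compute_feat_match hooks. A hypothetical wrapper implementing the standard non-saturating GAN losses through these hooks, assuming the discriminator returns raw logits, might look like this sketch:

import torch
import torch.nn.functional as F

class ExampleGANLossWrapper:
    """Hypothetical loss wrapper matching the hooks used above (illustration only)."""

    def compute_gan_D_real(self, d_out_real):
        # discriminator should score real data as 1
        return F.binary_cross_entropy_with_logits(
            d_out_real, torch.ones_like(d_out_real))

    def compute_gan_D_fake(self, d_out_fake):
        # discriminator should score generated data as 0
        return F.binary_cross_entropy_with_logits(
            d_out_fake, torch.zeros_like(d_out_fake))

    def compute_gan_G(self, d_out_fake):
        # generator tries to push the discriminator output towards 1
        return F.binary_cross_entropy_with_logits(
            d_out_fake, torch.ones_like(d_out_fake))

    def compute_aux(self, data_gen, data_tar):
        # optional auxiliary reconstruction loss on the generated data
        return F.l1_loss(data_gen, data_tar)
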
Example 3
def f_run_one_epoch(args,
                    pt_model, loss_wrapper, \
                    device, monitor,  \
                    data_loader, epoch_idx, optimizer = None, \
                    target_norm_method = None):
    """
    f_run_one_epoch: 
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss function
                     loss_wrapper.compute(generated, target) 
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader. 
       epoch_idx:    int, index of the current epoch
       optimizer:    torch optimizer or None
                     if None, back-propagation will be skipped
                     (for the development set)
       target_norm_method: method to normalize target data
                           (by default, use pt_model.normalize_target)
    """
    # timer
    start_time = time.time()

    # loop over samples
    pbar = tqdm(data_loader)
    epoch_num = monitor.get_max_epoch()
    for data_idx, (data_in, data_tar, data_info, idx_orig) in enumerate(pbar):
        pbar.set_description("Epoch: {}/{}".format(epoch_idx, epoch_num))
        # idx_orig is the original idx in the dataset
        # which can be different from data_idx when shuffle = True
        #idx_orig = idx_orig.numpy()[0]
        #data_seq_info = data_info[0]

        # reset gradients before the forward pass
        if optimizer is not None:
            optimizer.zero_grad()

        # compute
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # specifcal case when model.forward requires data_info
                data_gen = pt_model(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model(data_in)

        # compute loss and back-propagate
        # default values, used when data_tar is not a tensor and the
        # loss computation below is skipped
        loss_vals = [0]
        loss_flags = [True]
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside loss
            # thus, do normalization here
            if target_norm_method is None:
                normed_target = pt_model.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)

            # return the loss from loss_wrapper
            # loss_computed may be [[loss_1, loss_2, ...],[flag_1, flag_2,.]]
            #   which contains multiple losses and flags indicating whether
            #   the corresponding loss should be taken into consideration
            #   for early stopping
            # or
            # loss_computed may be simply a tensor loss
            loss_computed = loss_wrapper.compute(data_gen, normed_target)

            # To handle cases where there are multiple loss functions
            # when loss_computed is [[loss_1, loss_2, ...],[flag_1, flag_2, ...]]
            #   loss: sum of [loss_1, loss_2, ...], for backward()
            #   loss_vals: [loss_1.item(), loss_2.item() ..], for logging
            #   loss_flags: [True/False, ...], for logging,
            #               whether loss_n is used for early stopping
            # when loss_computed is loss
            #   loss: loss
            #   loss_vals: [loss.item()]
            #   loss_flags: [True]
            loss, loss_vals, loss_flags = nii_nn_tools.f_process_loss(
                loss_computed)

            # Back-propagation using the summed loss
            if optimizer is not None:
                loss.backward()
                optimizer.step()

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # each entry in loss_vals is already the average loss over
            # samples in the batch, so log loss_vals directly rather
            # than loss_vals / batchsize
            monitor.log_loss(loss_vals, loss_flags, \
                             (end_time-start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx*batchsize + idx,\
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    pbar.close()
    return
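
Since optimizer=None skips back-propagation, the same function can serve both the training and the validation pass of an epoch. A hypothetical driver loop (the monitor, loader and optimizer objects here are placeholders, not part of this code base) could be:

import torch

def run_training(args, pt_model, loss_wrapper, device,
                 train_monitor, val_monitor,
                 train_loader, val_loader, optimizer, num_epochs):
    """Illustrative driver around f_run_one_epoch (not the project's trainer)."""
    for epoch_idx in range(num_epochs):
        # training pass: optimizer is given, so gradients are applied
        pt_model.train()
        f_run_one_epoch(args, pt_model, loss_wrapper, device,
                        train_monitor, train_loader, epoch_idx,
                        optimizer=optimizer)

        # validation pass: optimizer=None, back-propagation is skipped
        pt_model.eval()
        with torch.no_grad():
            f_run_one_epoch(args, pt_model, loss_wrapper, device,
                            val_monitor, val_loader, epoch_idx,
                            optimizer=None)
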
def f_run_one_epoch_WGAN(
        args, pt_model_G, pt_model_D,
        loss_wrapper, \
        device, monitor,  \
        data_loader, epoch_idx,
        optimizer_G = None, optimizer_D = None, \
        target_norm_method = None):
    """
    f_run_one_epoch_WGAN: 
       similar to f_run_one_epoch_GAN, but for WGAN
    """
    # timer
    start_time = time.time()

    # number of critic updates per generator update (default 5)
    # note: argparse attribute names use underscores, not dashes
    if hasattr(args, "wgan_critic_num"):
        num_critic = args.wgan_critic_num
    else:
        num_critic = 5
    # weight-clipping value for the discriminator (default 0.01)
    if hasattr(args, "wgan_clamp"):
        wgan_clamp = args.wgan_clamp
    else:
        wgan_clamp = 0.01

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        # reset gradients before the forward pass
        if optimizer_G is not None:
            optimizer_G.zero_grad()
        if optimizer_D is not None:
            optimizer_D.zero_grad()

        # prepare data
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside loss
            # thus, do normalization here
            if target_norm_method is None:
                normed_target = pt_model_G.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)
        else:
            nii_display.f_die("target data is required")

        # to device (we assume noise will be generated by the model itself)
        # here we only provide external condition
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)

        ############################
        # Update Discriminator
        ############################
        # train with real
        pt_model_D.zero_grad()
        d_out_real = pt_model_D(data_tar)
        errD_real = loss_wrapper.compute_gan_D_real(d_out_real)
        if optimizer_D is not None:
            errD_real.backward()
        d_out_real_mean = d_out_real.mean()

        # train with fake
        #  generate sample
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model_G(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model_G(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model_G(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model_G(data_in)

        # data_gen.detach() is required
        # https://github.com/pytorch/examples/issues/116
        d_out_fake = pt_model_D(data_gen.detach())
        errD_fake = loss_wrapper.compute_gan_D_fake(d_out_fake)
        if optimizer_D is not None:
            errD_fake.backward()
        d_out_fake_mean = d_out_fake.mean()

        errD = errD_real + errD_fake
        if optimizer_D is not None:
            optimizer_D.step()

        # clip weights of discriminator
        for p in pt_model_D.parameters():
            p.data.clamp_(-wgan_clamp, wgan_clamp)

        ############################
        # Update Generator
        ############################
        pt_model_G.zero_grad()
        d_out_fake_for_G = pt_model_D(data_gen)
        errG_gan = loss_wrapper.compute_gan_G(d_out_fake_for_G)
        errG_aux = loss_wrapper.compute_aux(data_gen, data_tar)
        errG = errG_gan + errG_aux

        # only update the generator once every num_critic discriminator updates
        if data_idx % num_critic == 0 and optimizer_G is not None:
            errG.backward()
            optimizer_G.step()

        d_out_fake_for_G_mean = d_out_fake_for_G.mean()

        # construct the loss for logging and early stopping
        # only use errG_aux for early-stopping
        loss_computed = [[
            errG_aux, errG_gan, errD_real, errD_fake, d_out_real_mean,
            d_out_fake_mean, d_out_fake_for_G_mean
        ], [True, False, False, False, False, False, False]]

        # to handle cases where there are multiple loss functions
        loss, loss_vals, loss_flags = nii_nn_tools.f_process_loss(
            loss_computed)

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # each entry in loss_vals is already the average loss over
            # samples in the batch, so log loss_vals directly rather
            # than loss_vals / batchsize
            monitor.log_loss(loss_vals, loss_flags, \
                             (end_time-start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx*batchsize + idx,\
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    return
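
The weight clamping and the num_critic update ratio above follow the original WGAN training recipe, where the critic score is used directly rather than as a classification probability. A hypothetical loss wrapper consistent with those hooks (illustration only) could be:

import torch
import torch.nn.functional as F

class ExampleWGANLossWrapper:
    """Hypothetical WGAN loss wrapper matching the hooks used above.

    The critic maximizes D(real) - D(fake); written here as losses to minimize.
    """

    def compute_gan_D_real(self, d_out_real):
        # minimize -D(real), i.e. push the critic score on real data up
        return -torch.mean(d_out_real)

    def compute_gan_D_fake(self, d_out_fake):
        # minimize D(fake), i.e. push the critic score on generated data down
        return torch.mean(d_out_fake)

    def compute_gan_G(self, d_out_fake):
        # generator maximizes the critic score on generated data
        return -torch.mean(d_out_fake)

    def compute_aux(self, data_gen, data_tar):
        # auxiliary reconstruction term, used by f_run_one_epoch_WGAN above
        return F.l1_loss(data_gen, data_tar)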