Example #1
def f_train_wrapper(args, pt_model, loss_wrapper, device, \
                    optimizer_wrapper, \
                    train_dataset_wrapper, \
                    val_dataset_wrapper = None, \
                    checkpoint = None):
    """ 
    f_train_wrapper(args, pt_model, loss_wrapper, device, 
                    optimizer_wrapper,
                    train_dataset_wrapper, val_dataset_wrapper = None,
                    checkpoint = None):
      A wrapper to run the training process

    Args:
       args:         argument information given by argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss function
                     loss_wrapper.compute(generated, target) 
       device:       torch.device("cuda") or torch.device("cpu")

       optimizer_wrapper: 
           a wrapper over optimizer (defined in op_manager.py)
           optimizer_wrapper.optimizer is a torch.optim optimizer
    
       train_dataset_wrapper: 
           a wrapper over training data set (data_io/default_data_io.py)
           train_dataset_wrapper.get_loader() returns torch.utils.data.DataLoader
       
       val_dataset_wrapper: 
           a wrapper over validation data set (data_io/default_data_io.py)
           it can be None.
       
       checkpoint:
           a checkpoint that stores everything needed to resume training
    """

    nii_display.f_print_w_date("Start model training")

    ##############
    ## Preparation
    ##############

    # get the optimizer
    optimizer_wrapper.print_info()
    optimizer = optimizer_wrapper.optimizer
    lr_scheduler = optimizer_wrapper.lr_scheduler
    epoch_num = optimizer_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_wrapper.get_no_best_epoch_num()

    # get data loader for training set
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    train_seq_num = train_dataset_wrapper.get_seq_num()

    # get the training process monitor
    monitor_trn = nii_monitor.Monitor(epoch_num, train_seq_num)

    # if validation data is provided, get data loader for val set
    if val_dataset_wrapper is not None:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        monitor_val = nii_monitor.Monitor(epoch_num, val_seq_num)
    else:
        monitor_val = None

    # training log information
    train_log = ''

    # prepare for DataParallelism if available
    # pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    if torch.cuda.device_count() > 1 and args.multi_gpu_data_parallel:
        flag_multi_device = True
        nii_display.f_print("\nUse %d GPUs\n" % (torch.cuda.device_count()))
        # no way to call normtarget_f after pt_model is in DataParallel
        normtarget_f = pt_model.normalize_target
        pt_model = nn.DataParallel(pt_model)
    else:
        nii_display.f_print("\nUse single GPU: %s\n" % \
                            (torch.cuda.get_device_name(device)))
        flag_multi_device = False
        normtarget_f = None
    pt_model.to(device, dtype=nii_dconf.d_dtype)

    # print the network
    nii_nn_tools.f_model_show(pt_model)
    nii_nn_tools.f_loss_show(loss_wrapper)

    ###############################
    ## Resume training if necessary
    ###############################
    # resume training or initialize the model if necessary
    cp_names = nii_nn_manage_conf.CheckPointKey()
    if checkpoint is not None:
        if type(checkpoint) is dict:
            # checkpoint

            # load model parameter and optimizer state
            if cp_names.state_dict in checkpoint:
                # wrap the state_dict in f_state_dict_wrapper
                # in case the model is saved when DataParallel is on
                pt_model.load_state_dict(
                    nii_nn_tools.f_state_dict_wrapper(
                        checkpoint[cp_names.state_dict], flag_multi_device))

            # load optimizer state
            if cp_names.optimizer in checkpoint and \
               not args.ignore_optimizer_statistics_in_trained_model:
                optimizer.load_state_dict(checkpoint[cp_names.optimizer])

            # optionally, load training history
            if not args.ignore_training_history_in_trained_model:
                #nii_display.f_print("Load ")
                if cp_names.trnlog in checkpoint:
                    monitor_trn.load_state_dic(checkpoint[cp_names.trnlog])
                if cp_names.vallog in checkpoint and monitor_val:
                    monitor_val.load_state_dic(checkpoint[cp_names.vallog])
                if cp_names.info in checkpoint:
                    train_log = checkpoint[cp_names.info]
                if cp_names.lr_scheduler in checkpoint and \
                   checkpoint[cp_names.lr_scheduler] and lr_scheduler.f_valid():
                    lr_scheduler.f_load_state_dict(
                        checkpoint[cp_names.lr_scheduler])

                nii_display.f_print("Load check point, resume training")
            else:
                nii_display.f_print("Load pretrained model and optimizer")
        else:
            # only model status
            pt_model.load_state_dict(
                nii_nn_tools.f_state_dict_wrapper(checkpoint,
                                                  flag_multi_device))
            nii_display.f_print("Load pretrained model")

    ######################
    ### User defined setup
    ######################
    if hasattr(pt_model, "other_setups"):
        nii_display.f_print("Conduct User-defined setup")
        pt_model.other_setups()

    # This should be merged with other_setups
    if hasattr(pt_model, "g_pretrained_model_path") and \
       hasattr(pt_model, "g_pretrained_model_prefix"):
        nii_display.f_print("Load pret-rained models as part of this mode")
        nii_nn_tools.f_load_pretrained_model_partially(
            pt_model, pt_model.g_pretrained_model_path,
            pt_model.g_pretrained_model_prefix)

    ######################
    ### Start training
    ######################
    # other variables
    flag_early_stopped = False
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    # print
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')

    # loop over multiple epochs
    for epoch_idx in range(start_epoch, epoch_num):

        # training one epoch
        pt_model.train()
        # set validation flag if necessary
        if hasattr(pt_model, 'validation'):
            pt_model.validation = False
            mes = "Warning: model.validation is deprecated, "
            mes += "please use model.flag_validation"
            nii_display.f_print(mes, 'warning')
        if hasattr(pt_model, 'flag_validation'):
            pt_model.flag_validation = False

        f_run_one_epoch(args, pt_model, loss_wrapper, device, \
                        monitor_trn, train_data_loader, \
                        epoch_idx, optimizer, normtarget_f)
        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)

        # if necessary, do validation
        if val_dataset_wrapper is not None:
            # set eval() if necessary
            if args.eval_mode_for_validation:
                pt_model.eval()

            # set validation flag if necessary
            if hasattr(pt_model, 'validation'):
                pt_model.validation = True
                mes = "Warning: model.validation is deprecated, "
                mes += "please use model.flag_validation"
                nii_display.f_print(mes, 'warning')
            if hasattr(pt_model, 'flag_validation'):
                pt_model.flag_validation = True

            with torch.no_grad():
                f_run_one_epoch(args, pt_model, loss_wrapper, \
                                device, \
                                monitor_val, val_data_loader, \
                                epoch_idx, None, normtarget_f)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)

            # update the learning rate scheduler if necessary
            if lr_scheduler.f_valid():
                lr_scheduler.f_step(loss_val)

        else:
            time_val, loss_val = 0, 0

        if val_dataset_wrapper is not None:
            flag_new_best = monitor_val.is_new_best()
        else:
            flag_new_best = True

        # print information
        train_log += nii_op_display_tk.print_train_info(
            epoch_idx, time_trn, loss_trn, time_val, loss_val, flag_new_best,
            optimizer_wrapper.get_lr_info())

        # save the best model
        if flag_new_best:
            tmp_best_name = nii_nn_tools.f_save_trained_name(args)
            torch.save(pt_model.state_dict(), tmp_best_name)

        # save intermediate model if necessary
        if not args.not_save_each_epoch:
            tmp_model_name = nii_nn_tools.f_save_epoch_name(args, epoch_idx)

            if monitor_val is not None:
                tmp_val_log = monitor_val.get_state_dic()
            else:
                tmp_val_log = None

            if lr_scheduler.f_valid():
                lr_scheduler_state = lr_scheduler.f_state_dict()
            else:
                lr_scheduler_state = None

            # save
            tmp_dic = {
                cp_names.state_dict: pt_model.state_dict(),
                cp_names.info: train_log,
                cp_names.optimizer: optimizer.state_dict(),
                cp_names.trnlog: monitor_trn.get_state_dic(),
                cp_names.vallog: tmp_val_log,
                cp_names.lr_scheduler: lr_scheduler_state
            }
            torch.save(tmp_dic, tmp_model_name)
            if args.verbose == 1:
                nii_display.f_eprint(str(datetime.datetime.now()))
                nii_display.f_eprint("Save {:s}".format(tmp_model_name),
                                     flush=True)

        # Early stopping
        #  note: if LR scheduler is used, early stopping will be
        #  disabled
        if lr_scheduler.f_allow_early_stopping() and \
           monitor_val is not None and \
           monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break

    # loop done
    nii_op_display_tk.print_log_tail()
    if flag_early_stopped:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end='')
    nii_display.f_print("{}".format(nii_nn_tools.f_save_trained_name(args)))
    return
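
# The loss_wrapper passed to f_train_wrapper only needs a
# compute(generated, target) method returning a scalar loss tensor, as
# documented above. Below is a minimal illustrative sketch of such a wrapper;
# the class name ExampleLoss and the choice of MSE are assumptions and are
# not part of this toolkit.
import torch

class ExampleLoss:
    """ A hypothetical loss wrapper satisfying loss_wrapper.compute() """
    def __init__(self, args):
        # args: command-line arguments (unused in this sketch)
        self.m_loss = torch.nn.MSELoss()

    def compute(self, generated, target):
        # return a scalar loss tensor computed from model output and target
        return self.m_loss(generated, target)
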
def main():
    """ main(): the default wrapper for training and inference process
    Please prepare config.py and model.py
    """
    # arguments initialization
    args = nii_arg_parse.f_args_parsed()

    #
    nii_warn.f_print_w_date("Start program", level='h')
    nii_warn.f_print("Load module: %s" % (args.module_config))
    nii_warn.f_print("Load module: %s" % (args.module_model))
    prj_conf = importlib.import_module(args.module_config)
    prj_model = importlib.import_module(args.module_model)

    # initialization
    nii_startup.set_random_seed(args.seed, args)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # prepare data io
    if not args.inference:
        params = {
            'batch_size': args.batch_size,
            'shuffle': args.shuffle,
            'num_workers': args.num_workers
        }

        # Load file list and create data loader
        trn_lst = nii_list_tool.read_list_from_text(prj_conf.trn_list)
        trn_set = nii_dset.NIIDataSetLoader(
            prj_conf.trn_set_name, \
            trn_lst,
            prj_conf.input_dirs, \
            prj_conf.input_exts, \
            prj_conf.input_dims, \
            prj_conf.input_reso, \
            prj_conf.input_norm, \
            prj_conf.output_dirs, \
            prj_conf.output_exts, \
            prj_conf.output_dims, \
            prj_conf.output_reso, \
            prj_conf.output_norm, \
            './',
            params = params,
            truncate_seq = prj_conf.truncate_seq,
            min_seq_len = prj_conf.minimum_len,
            save_mean_std = True,
            wav_samp_rate = prj_conf.wav_samp_rate)

        if prj_conf.val_list is not None:
            val_lst = nii_list_tool.read_list_from_text(prj_conf.val_list)
            val_set = nii_dset.NIIDataSetLoader(
                prj_conf.val_set_name,
                val_lst,
                prj_conf.input_dirs, \
                prj_conf.input_exts, \
                prj_conf.input_dims, \
                prj_conf.input_reso, \
                prj_conf.input_norm, \
                prj_conf.output_dirs, \
                prj_conf.output_exts, \
                prj_conf.output_dims, \
                prj_conf.output_reso, \
                prj_conf.output_norm, \
                './', \
                params = params,
                truncate_seq= prj_conf.truncate_seq,
                min_seq_len = prj_conf.minimum_len,
                save_mean_std = False,
                wav_samp_rate = prj_conf.wav_samp_rate)
        else:
            val_set = None

        # initialize the model and loss function
        model = prj_model.Model(trn_set.get_in_dim(), \
                                trn_set.get_out_dim(), \
                                args, trn_set.get_data_mean_std())
        loss_wrapper = prj_model.Loss(args)

        # initialize the optimizer
        optimizer_wrapper = nii_op_wrapper.OptimizerWrapper(model, args)

        # if necessary, resume training
        if args.trained_model == "":
            checkpoint = None
        else:
            checkpoint = torch.load(args.trained_model)

        # start training
        nii_nn_wrapper.f_train_wrapper(args, model, loss_wrapper, device,
                                       optimizer_wrapper, trn_set, val_set,
                                       checkpoint)
        # done for training

    else:

        # for inference

        # default, no truncating, no shuffling
        params = {
            'batch_size': args.batch_size,
            'shuffle': False,
            'num_workers': args.num_workers
        }

        if type(prj_conf.test_list) is list:
            t_lst = prj_conf.test_list
        else:
            t_lst = nii_list_tool.read_list_from_text(prj_conf.test_list)
        test_set = nii_dset.NIIDataSetLoader(
            prj_conf.test_set_name, \
            t_lst, \
            prj_conf.test_input_dirs,
            prj_conf.input_exts,
            prj_conf.input_dims,
            prj_conf.input_reso,
            prj_conf.input_norm,
            prj_conf.test_output_dirs,
            prj_conf.output_exts,
            prj_conf.output_dims,
            prj_conf.output_reso,
            prj_conf.output_norm,
            './',
            params = params,
            truncate_seq= None,
            min_seq_len = None,
            save_mean_std = False,
            wav_samp_rate = prj_conf.wav_samp_rate)

        # initialize model
        model = prj_model.Model(test_set.get_in_dim(), \
                                test_set.get_out_dim(), \
                                args)
        if args.trained_model == "":
            print("No model is loaded by ---trained-model for inference")
            print("By default, load %s%s" %
                  (args.save_trained_name, args.save_model_ext))
            checkpoint = torch.load(
                "%s%s" % (args.save_trained_name, args.save_model_ext))
        else:
            checkpoint = torch.load(args.trained_model)

        # do inference and output data
        nii_nn_wrapper.f_inference_wrapper(args, model, device, \
                                           test_set, checkpoint)
    # done
    return
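
# main() imports the project configuration module given by args.module_config.
# A minimal, hypothetical config.py could look like the sketch below; only
# attributes that main() actually reads are listed, and all names, paths,
# extensions, and dimensions are placeholder values.

# ---- sketch of a separate config.py (illustrative only) ----
trn_set_name = 'train_set'
val_set_name = 'val_set'
trn_list = 'train_file_list.txt'        # list of file names, one per line
val_list = 'val_file_list.txt'          # set to None to skip validation

input_dirs = ['./data/feature']         # one entry per input feature stream
input_exts = ['.mel']
input_dims = [80]
input_reso = [1]                        # temporal resolution, or None
input_norm = [True]                     # whether to normalize this feature

output_dirs = ['./data/target']
output_exts = ['.f0']
output_dims = [1]
output_reso = [1]
output_norm = [False]

truncate_seq = None                     # or an int to truncate long sequences
minimum_len = None                      # or an int minimum sequence length
wav_samp_rate = 16000

test_set_name = 'test_set'
test_list = 'test_file_list.txt'        # or a Python list of file names
test_input_dirs = ['./data/test_feature']
test_output_dirs = ['./data/test_target']
# ---- end of config.py sketch ----
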
def f_train_wrapper_GAN(
        args, pt_model_G, pt_model_D, loss_wrapper, device, \
        optimizer_G_wrapper, optimizer_D_wrapper, \
        train_dataset_wrapper, \
        val_dataset_wrapper = None, \
        checkpoint_G = None, checkpoint_D = None):
    """ 
    f_train_wrapper_GAN(
       args, pt_model_G, pt_model_D, loss_wrapper, device, 
       optimizer_G_wrapper, optimizer_D_wrapper, 
       train_dataset_wrapper, val_dataset_wrapper = None,
       checkpoint_G = None, checkpoint_D = None):

      A wrapper to run the training process

    Args:
       args:         argument information given by argparse
       pt_model_G:   generator, pytorch model (torch.nn.Module)
       pt_model_D:   discriminator, pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss functions
                     loss_wrapper.compute_D_real(discriminator_output) 
                     loss_wrapper.compute_D_fake(discriminator_output) 
                     loss_wrapper.compute_G(discriminator_output)
                     loss_wrapper.compute_G(fake, real)

       device:       torch.device("cuda") or torch.device("cpu")

       optimizer_G_wrapper: 
           an optimizer wrapper for the generator (defined in op_manager.py)
       optimizer_D_wrapper: 
           an optimizer wrapper for the discriminator (defined in op_manager.py)
       
       train_dataset_wrapper: 
           a wrapper over training data set (data_io/default_data_io.py)
           train_dataset_wrapper.get_loader() returns torch.utils.data.DataLoader
       
       val_dataset_wrapper: 
           a wrapper over validation data set (data_io/default_data_io.py)
           it can be None.
       
       checkpoint_G:
           a checkpoint that stores everything needed to resume training

       checkpoint_D:
           a checkpoint that stores everything needed to resume training
    """

    nii_display.f_print_w_date("Start model training")

    # get the optimizer
    optimizer_G_wrapper.print_info()
    optimizer_D_wrapper.print_info()
    optimizer_G = optimizer_G_wrapper.optimizer
    optimizer_D = optimizer_D_wrapper.optimizer
    epoch_num = optimizer_G_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_G_wrapper.get_no_best_epoch_num()

    # get data loader for training set
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    train_seq_num = train_dataset_wrapper.get_seq_num()

    # get the training process monitor
    monitor_trn = nii_monitor.Monitor(epoch_num, train_seq_num)

    # if validation data is provided, get data loader for val set
    if val_dataset_wrapper is not None:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        monitor_val = nii_monitor.Monitor(epoch_num, val_seq_num)
    else:
        monitor_val = None

    # training log information
    train_log = ''
    model_tags = ["_G", "_D"]

    # prepare for DataParallism if available
    # pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    if torch.cuda.device_count() > 1 and args.multi_gpu_data_parallel:
        nii_display.f_die("data_parallel not implemented for GAN")
    else:
        nii_display.f_print("Use single GPU: %s" % \
                            (torch.cuda.get_device_name(device)))
        flag_multi_device = False
        normtarget_f = None
    pt_model_G.to(device, dtype=nii_dconf.d_dtype)
    pt_model_D.to(device, dtype=nii_dconf.d_dtype)

    # print the network
    nii_display.f_print("Setup generator")
    f_model_show(pt_model_G)
    nii_display.f_print("Setup discriminator")
    f_model_show(pt_model_D)

    # resume training or initialize the model if necessary
    cp_names = CheckPointKey()
    if checkpoint_G is not None or checkpoint_D is not None:
        for checkpoint, optimizer, pt_model, model_name in \
            zip([checkpoint_G, checkpoint_D], [optimizer_G, optimizer_D],
                [pt_model_G, pt_model_D], ["Generator", "Discriminator"]):
            nii_display.f_print("For %s" % (model_name))
            if type(checkpoint) is dict:
                # checkpoint
                # load model parameter and optimizer state
                if cp_names.state_dict in checkpoint:
                    # wrap the state_dict in f_state_dict_wrapper
                    # in case the model is saved when DataParallel is on
                    pt_model.load_state_dict(
                        nii_nn_tools.f_state_dict_wrapper(
                            checkpoint[cp_names.state_dict],
                            flag_multi_device))
                # load optimizer state
                if cp_names.optimizer in checkpoint:
                    optimizer.load_state_dict(checkpoint[cp_names.optimizer])
                # optionally, load training history
                if not args.ignore_training_history_in_trained_model:
                    #nii_display.f_print("Load ")
                    if cp_names.trnlog in checkpoint:
                        monitor_trn.load_state_dic(checkpoint[cp_names.trnlog])
                    if cp_names.vallog in checkpoint and monitor_val:
                        monitor_val.load_state_dic(checkpoint[cp_names.vallog])
                    if cp_names.info in checkpoint:
                        train_log = checkpoint[cp_names.info]
                    nii_display.f_print("Load check point, resume training")
                else:
                    nii_display.f_print("Load pretrained model and optimizer")
            elif checkpoint is not None:
                # only model status
                #pt_model.load_state_dict(checkpoint)
                pt_model.load_state_dict(
                    nii_nn_tools.f_state_dict_wrapper(checkpoint,
                                                      flag_multi_device))
                nii_display.f_print("Load pretrained model")
            else:
                nii_display.f_print("No pretrained model")

    # done for resume training

    # other variables
    flag_early_stopped = False
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    if hasattr(loss_wrapper, "flag_wgan") and loss_wrapper.flag_wgan:
        f_wrapper_gan_one_epoch = f_run_one_epoch_WGAN
    else:
        f_wrapper_gan_one_epoch = f_run_one_epoch_GAN

    # print
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')

    # loop over multiple epochs
    for epoch_idx in range(start_epoch, epoch_num):

        # training one epoch
        pt_model_D.train()
        pt_model_G.train()

        f_wrapper_gan_one_epoch(
            args, pt_model_G, pt_model_D,
            loss_wrapper, device, \
            monitor_trn, train_data_loader, \
            epoch_idx, optimizer_G, optimizer_D,
            normtarget_f)

        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)

        # if necessary, do validation
        if val_dataset_wrapper is not None:
            # set eval() if necessary
            if args.eval_mode_for_validation:
                pt_model_G.eval()
                pt_model_D.eval()
            with torch.no_grad():
                f_wrapper_gan_one_epoch(
                    args, pt_model_G, pt_model_D,
                    loss_wrapper, \
                    device, \
                    monitor_val, val_data_loader, \
                    epoch_idx, None, None, normtarget_f)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)
        else:
            time_val, loss_val = 0, 0

        if val_dataset_wrapper is not None:
            flag_new_best = monitor_val.is_new_best()
        else:
            flag_new_best = True

        # print information
        train_log += nii_op_display_tk.print_train_info(
            epoch_idx, time_trn, loss_trn, time_val, loss_val, flag_new_best)

        # save the best model
        if flag_new_best:
            for pt_model, model_tag in \
                zip([pt_model_G, pt_model_D], model_tags):
                tmp_best_name = f_save_trained_name_GAN(args, model_tag)
                torch.save(pt_model.state_dict(), tmp_best_name)

        # save intermediate model if necessary
        if not args.not_save_each_epoch:
            # save the discriminator and generator models
            for pt_model, optimizer, model_tag in \
                zip([pt_model_G, pt_model_D], [optimizer_G, optimizer_D],
                    model_tags):

                tmp_model_name = f_save_epoch_name_GAN(args, epoch_idx,
                                                       model_tag)
                if monitor_val is not None:
                    tmp_val_log = monitor_val.get_state_dic()
                else:
                    tmp_val_log = None
                # save
                tmp_dic = {
                    cp_names.state_dict: pt_model.state_dict(),
                    cp_names.info: train_log,
                    cp_names.optimizer: optimizer.state_dict(),
                    cp_names.trnlog: monitor_trn.get_state_dic(),
                    cp_names.vallog: tmp_val_log
                }
                torch.save(tmp_dic, tmp_model_name)
                if args.verbose == 1:
                    nii_display.f_eprint(str(datetime.datetime.now()))
                    nii_display.f_eprint("Save {:s}".format(tmp_model_name),
                                         flush=True)

        # early stopping
        if monitor_val is not None and \
           monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break

    # loop done

    nii_op_display_tk.print_log_tail()
    if flag_early_stopped:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end='')
    for model_tag in model_tags:
        nii_display.f_print("{}".format(
            f_save_trained_name_GAN(args, model_tag)))
    return
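
# For f_train_wrapper_GAN, the loss wrapper exposes the methods listed in the
# docstring above (compute_D_real, compute_D_fake, compute_G) and may set
# flag_wgan to select the WGAN training loop. A minimal illustrative sketch,
# assuming a standard non-saturating GAN loss built on BCEWithLogitsLoss
# (this concrete choice is an assumption, not part of this toolkit):
import torch

class ExampleGANLoss:
    """ A hypothetical GAN loss wrapper; not part of this toolkit """
    def __init__(self, args):
        # flag_wgan = False selects f_run_one_epoch_GAN rather than the WGAN loop
        self.flag_wgan = False
        self.m_bce = torch.nn.BCEWithLogitsLoss()

    def compute_D_real(self, discriminator_output):
        # the discriminator should classify real data as 1
        return self.m_bce(discriminator_output,
                          torch.ones_like(discriminator_output))

    def compute_D_fake(self, discriminator_output):
        # the discriminator should classify generated data as 0
        return self.m_bce(discriminator_output,
                          torch.zeros_like(discriminator_output))

    def compute_G(self, discriminator_output):
        # the generator is trained to make the discriminator output 1 on fake data
        return self.m_bce(discriminator_output,
                          torch.ones_like(discriminator_output))
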
def f_train_wrapper(args, pt_model, loss_wrapper, device, \
                    optimizer_wrapper, \
                    train_dataset_wrapper, \
                    val_dataset_wrapper = None, \
                    checkpoint = None):
    """ 
    f_train_wrapper(args, pt_model, loss_wrapper, device, 
                    optimizer_wrapper,
                    train_dataset_wrapper, val_dataset_wrapper = None,
                    checkpoint = None):
      A wrapper to run the training process

    Args:
       args:         argument information given by argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss function
                     loss_wrapper.compute(generated, target) 
       device:       torch.device("cuda") or torch.device("cpu")

       optimizer_wrapper: 
           a wrapper over optimizer (defined in op_manager.py)
           optimizer_wrapper.optimizer is a torch.optim optimizer
    
       train_dataset_wrapper: 
           a wrapper over training data set (data_io/default_data_io.py)
           train_dataset_wrapper.get_loader() returns torch.utils.data.DataLoader
       
       val_dataset_wrapper: 
           a wrapper over validation data set (data_io/default_data_io.py)
           it can be None.
       
       checkpoint:
           a checkpoint that stores everything needed to resume training
    """

    nii_display.f_print_w_date("Start model training")

    # get the optimizer
    optimizer_wrapper.print_info()
    optimizer = optimizer_wrapper.optimizer
    epoch_num = optimizer_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_wrapper.get_no_best_epoch_num()

    # get data loader for training set
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    train_seq_num = train_dataset_wrapper.get_seq_num()

    # get the training process monitor
    monitor_trn = nii_monitor.Monitor(epoch_num, train_seq_num)

    # if validation data is provided, get data loader for val set
    if val_dataset_wrapper is not None:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        monitor_val = nii_monitor.Monitor(epoch_num, val_seq_num)
    else:
        monitor_val = None

    # training log information
    train_log = ''

    # print the network
    pt_model.to(device, dtype=nii_dconf.d_dtype)
    f_model_show(pt_model)

    # resume training or initialize the model if necessary
    cp_names = CheckPointKey()
    if checkpoint is not None:
        if type(checkpoint) is dict:
            # checkpoint
            if cp_names.state_dict in checkpoint:
                pt_model.load_state_dict(checkpoint[cp_names.state_dict])
            if cp_names.optimizer in checkpoint:
                optimizer.load_state_dict(checkpoint[cp_names.optimizer])
            if cp_names.trnlog in checkpoint:
                monitor_trn.load_state_dic(checkpoint[cp_names.trnlog])
            if cp_names.vallog in checkpoint and monitor_val:
                monitor_val.load_state_dic(checkpoint[cp_names.vallog])
            if cp_names.info in checkpoint:
                train_log = checkpoint[cp_names.info]
            nii_display.f_print("Load check point and resume training")
        else:
            # only model status
            pt_model.load_state_dict(checkpoint)
            nii_display.f_print("Load pre-trained model")

    # other variables
    flag_early_stopped = False
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    # print
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')

    # loop over multiple epochs
    for epoch_idx in range(start_epoch, epoch_num):

        # training one epoch
        pt_model.train()
        f_run_one_epoch(args, pt_model, loss_wrapper, device, \
                        monitor_trn, train_data_loader, \
                        epoch_idx, optimizer)
        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)

        # if necessary, do validation
        if val_dataset_wrapper is not None:
            # set eval() if necessary
            if args.eval_mode_for_validation:
                pt_model.eval()
            with torch.no_grad():
                f_run_one_epoch(args, pt_model, loss_wrapper, \
                                device, \
                                monitor_val, val_data_loader, \
                                epoch_idx, None)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)
        else:
            time_val, loss_val = 0, 0

        if val_dataset_wrapper is not None:
            flag_new_best = monitor_val.is_new_best()
        else:
            flag_new_best = True

        # print information
        train_log += nii_op_display_tk.print_train_info(epoch_idx, \
                                                        time_trn, \
                                                        loss_trn, \
                                                        time_val, \
                                                        loss_val, \
                                                        flag_new_best)
        # save the best model
        if flag_new_best:
            tmp_best_name = f_save_trained_name(args)
            torch.save(pt_model.state_dict(), tmp_best_name)

        # save intermediate model if necessary
        if not args.not_save_each_epoch:
            tmp_model_name = f_save_epoch_name(args, epoch_idx)
            if monitor_val is not None:
                tmp_val_log = monitor_val.get_state_dic()
            else:
                tmp_val_log = None
            # save
            tmp_dic = {
                cp_names.state_dict: pt_model.state_dict(),
                cp_names.info: train_log,
                cp_names.optimizer: optimizer.state_dict(),
                cp_names.trnlog: monitor_trn.get_state_dic(),
                cp_names.vallog: tmp_val_log
            }
            torch.save(tmp_dic, tmp_model_name)
            if args.verbose == 1:
                nii_display.f_eprint(str(datetime.datetime.now()))
                nii_display.f_eprint("Save {:s}".format(tmp_model_name),
                                     flush=True)

        # early stopping
        if monitor_val is not None and \
           monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break

    # loop done
    nii_op_display_tk.print_log_tail()
    if flag_early_stopped:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end='')
    nii_display.f_print("{}".format(f_save_trained_name(args)))
    return
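
# Each epoch, the loop above saves a dict checkpoint whose entries are named
# by CheckPointKey (model state_dict, optimizer state, training/validation
# logs, and the printed training log). A small self-contained sketch for
# inspecting such a file offline; the file name below is a placeholder:
import torch

ckpt = torch.load('epoch_checkpoint.pt', map_location='cpu')
if isinstance(ckpt, dict):
    # intermediate checkpoint: list its entries
    for key, value in ckpt.items():
        print(key, type(value))
else:
    # best-model file: a bare state_dict saved via torch.save(model.state_dict())
    print(type(ckpt))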
Example #5
 def __init__(self,
              dataset_name, \
              file_list, \
              input_dirs, input_exts, input_dims, input_reso, \
              input_norm, \
              output_dirs, output_exts, output_dims, output_reso, \
              output_norm, \
              stats_path, \
              data_format = '<f4', \
              params = None, \
              truncate_seq = None, \
              min_seq_len = None,
              save_mean_std = True, \
              wav_samp_rate = None):
     """
     NIIDataSetLoader(
              dataset_name,
              file_list,
              input_dirs,
              input_exts,
              input_dims,
              input_reso,
              input_norm,
              output_dirs,
              output_exts,
              output_dims,
              output_reso,
              output_norm,
              stats_path,
              data_format = '<f4',
              params = None,
              truncate_seq = None,
              min_seq_len = None,
              save_mean_std = True,
              wav_samp_rate = None):
     Args:
         dataset_name: a string to name this dataset
                        this will be used to name the statistics files
                        such as the mean/std for this dataset
         file_list: a list of file name strings (without extension)
         input_dirs: a list of dirs from which each input feature is loaded
         input_exts: a list of input feature name extensions
         input_dims: a list of input feature dimensions
         input_reso: a list of input feature temporal resolution,
                     or None
         output_dirs: a list of dirs from which each output feature is loaded
         output_exts: a list of output feature name extensions
         output_dims: a list of output feature dimensions
         output_reso: a list of output feature temporal resolution, 
                      or None
         stats_path: path to the directory of statistics (mean/std)
         data_format: method to load the data
                 '<f4' (default): load data as float32m little-endian
                 'htk': load data as htk format
         params: parameter for torch.utils.data.DataLoader
         truncate_seq: None or int,
                       truncate data sequences into smaller chunks;
                       truncate_seq > 0 specifies the chunk length
         min_seq_len: None or int, minimum sequence length
         save_mean_std: bool, whether to save the mean/std statistics
         wav_samp_rate: None or int, waveform sampling rate
     Methods:
     get_loader(): return a torch.utils.data.DataLoader
     get_dataset(): return a torch.utils.data.Dataset
     """
     nii_warn.f_print_w_date("Loading dataset %s" % (dataset_name),
                             level="h")
     
      # create torch.utils.data.Dataset
     self.m_dataset = NIIDataSet(dataset_name, \
                                 file_list, \
                                 input_dirs, input_exts, \
                                 input_dims, input_reso, \
                                 input_norm, \
                                 output_dirs, output_exts, \
                                 output_dims, output_reso, \
                                 output_norm, \
                                 stats_path, data_format, \
                                 truncate_seq, min_seq_len,\
                                 save_mean_std, \
                                 wav_samp_rate)
     
      # create torch.utils.data.DataLoader
     if params is None:
         tmp_params = nii_dconf.default_loader_conf
     else:
         tmp_params = params
     self.m_loader = torch.utils.data.DataLoader(self.m_dataset,
                                                 **tmp_params)
     # done
     return
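
 # A hedged usage sketch: get_loader() returns a standard
 # torch.utils.data.DataLoader that can be iterated over directly. The layout
 # of each mini-batch is determined by NIIDataSet.__getitem__ and is treated
 # as opaque here; all constructor arguments below are placeholders.
 #
 #   data_wrapper = NIIDataSetLoader(
 #       'train_set', file_list,
 #       input_dirs, input_exts, input_dims, input_reso, input_norm,
 #       output_dirs, output_exts, output_dims, output_reso, output_norm,
 #       './', params={'batch_size': 8, 'shuffle': True, 'num_workers': 2})
 #
 #   for data_batch in data_wrapper.get_loader():
 #       # each batch is whatever NIIDataSet.__getitem__ yields, collated
 #       pass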