Exemplo n.º 1
0
def _get_current_parameter(args, config):
    """Restore the latest intermediate training state found in ``args.outdir``.

    Looks for files matching ``results_current_*.nnp``, takes the number
    between the last underscore and the extension as the epoch, and loads
    the file with the highest epoch via ``load_train_state`` into the
    optimizers of ``config``.  Files whose suffix is not an integer are
    skipped (and logged) instead of aborting the resume.

    Args:
        args: Command-line arguments; ``args.outdir`` is searched and
            ``args`` is passed to ``callback.get_best_from_status``.
        config: Training configuration whose ``optimizers`` maps names to
            objects exposing an ``optimizer`` attribute.

    Returns:
        tuple: ``(last_epoch, best_epoch, best_error)``; ``last_epoch`` is
        ``0`` when no usable file was found.
    """
    def convert_to_info(config):
        # load_train_state() is given an object exposing ``optimizers``
        # (name -> optimizer); build a minimal stand-in from ``config``.
        class Info:
            pass

        ret = Info()
        ret.optimizers = OrderedDict(
            (name, opt.optimizer) for name, opt in config.optimizers.items())
        return ret

    best_error, best_epoch = callback.get_best_from_status(args)

    globname = os.path.join(args.outdir, 'results_current_*.nnp')

    # Map epoch number -> file path.
    ex_list = {}
    for ex in glob.glob(globname):
        try:
            n = int(ex.rsplit('_', 1)[1].rsplit('.', 1)[0])
        except ValueError:
            # The glob also matches names like "results_current_x.nnp";
            # ignore them rather than failing the whole resume.
            logger.log(
                99, "Skip unrecognized parameter file [{}]".format(
                    os.path.basename(ex)))
            continue
        ex_list[n] = ex

    if not ex_list:
        return 0, best_epoch, best_error

    last_epoch = max(ex_list)
    last_parameter = ex_list[last_epoch]
    logger.log(
        99, "Load parameter from [{}]".format(
            os.path.basename(last_parameter)))
    load_train_state(last_parameter, convert_to_info(config))
    return last_epoch, best_epoch, best_error
Exemplo n.º 2
0
def train_command(args):
    """Run the ``train`` sub-command.

    Validates the out-of-core (OOC) options, loads the network and training
    configuration from ``args.config`` (without parameters), optionally
    restores a saved training state from ``args.param``, instantiates the
    data iterators and runs ``_train``.  When the configuration yields zero
    iterations, the parameters are saved without training
    ("0 epoch learning").

    Args:
        args: Parsed command-line arguments.  Reads ``ooc_gpu_memory_size``,
            ``ooc_window_length``, ``outdir``, ``config``, ``context`` and
            ``param``; the two OOC fields, when given, are replaced in place
            by the result of ``str_to_num``.

    Returns:
        bool: ``False`` if an OOC option is negative or a dataset URI is
        empty; ``True`` otherwise, including when training did not complete
        (that outcome is logged and, on rank zero, reported via
        ``callback.update_status``).
    """
    # Convert the optional OOC options to numbers; negative values are fatal.
    for attr in ('ooc_gpu_memory_size', 'ooc_window_length'):
        raw_value = getattr(args, attr)
        if raw_value is None:
            continue
        value = str_to_num(raw_value)
        if value < 0:
            logger.log(99, f'Fatal error. invalid {attr} [{raw_value}].')
            return False
        setattr(args, attr, value)

    callback.update_status(args)

    if single_or_rankzero():
        configure_progress(os.path.join(args.outdir, 'progress.txt'))

    info = load.load([args.config],
                     prepare_data_iterator=None,
                     exclude_parameter=True,
                     context=args.context)

    # Every dataset must have a non-blank URI.
    if any(dataset.uri.strip() == '' for dataset in info.datasets.values()):
        logger.log(99, 'Fatal error. Dataset URI is empty.')
        return False

    class TrainConfig:
        pass

    config = TrainConfig()
    if args.param:
        # The parameter file may also contain optimizer state, which is
        # restored into the optimizers held by ``info``.
        load_train_state(args.param, info)

    config.timelimit = callback.get_timelimit(args)

    config.global_config = info.global_config
    config.training_config = info.training_config

    if single_or_rankzero():
        logger.log(99, 'Train with contexts {}'.format(available_contexts))

    class OptConfig:
        pass

    config.optimizers = OrderedDict()
    for name, opt in info.optimizers.items():
        o = OptConfig()
        o.optimizer = opt
        o.data_iterators = []  # filled once the iterators are instantiated
        config.optimizers[name] = o

    class MonConfig:
        pass

    config.monitors = OrderedDict()
    for name, mon in info.monitors.items():
        m = MonConfig()
        m.monitor = mon
        m.data_iterators = []  # filled once the iterators are instantiated
        config.monitors[name] = m

    # Training
    comm = current_communicator()
    # Split the per-epoch iterations across processes when a communicator
    # is present.
    config.training_config.iter_per_epoch //= (comm.size if comm else 1)
    max_iteration = (config.training_config.max_epoch *
                     config.training_config.iter_per_epoch)

    global _save_parameter_info
    _save_parameter_info = {}
    _, config_ext = os.path.splitext(args.config)
    if config_ext in ('.prototxt', '.nntxt'):
        _save_parameter_info['config'] = args.config
    elif config_ext == '.nnp':
        # Extract the text network definition(s) into outdir; if several
        # are present, the last one listed is the one recorded.
        with zipfile.ZipFile(args.config, 'r') as nnp:
            for name in nnp.namelist():
                _, ext = os.path.splitext(name)
                if ext in ('.nntxt', '.prototxt'):
                    nnp.extract(name, args.outdir)
                    _save_parameter_info['config'] = os.path.join(
                        args.outdir, name)

    def attach_data_iterators(stack, entries, get_factories, rng):
        """Instantiate each distinct data-iterator factory once and share it.

        Each instance is entered on ``stack`` (so it is exited together with
        the stack) and, with more than one process, replaced by
        ``instance.slice(rng, comm.size, comm.rank)``.  The resulting
        instance is appended to ``entry.data_iterators`` for every entry
        whose factories include it.
        """
        instances = {}
        for entry in entries:
            for factory in get_factories(entry).values():
                if factory not in instances:
                    instance = stack.enter_context(factory())
                    if comm and comm.size > 1:
                        instance = instance.slice(rng, comm.size, comm.rank)
                    instances[factory] = instance
                entry.data_iterators.append(instances[factory])

    result = False
    restart = False
    if max_iteration > 0:
        rng = np.random.RandomState(comm.rank if comm else 0)
        with ExitStack() as stack:
            # Optimizers that use the same dataset share one instance.
            attach_data_iterators(stack, config.optimizers.values(),
                                  lambda o: o.optimizer.data_iterators, rng)
            # Monitors get their own instances, separate from the optimizers'
            # (presumably so evaluation does not advance the training
            # iterators).
            attach_data_iterators(stack, config.monitors.values(),
                                  lambda m: m.monitor.data_iterators, rng)

            result, restart = _train(args, config)
    else:
        # save parameters without training (0 epoch learning)
        logger.log(99, '0 epoch learning. (Just save parameter.)')
        if single_or_rankzero():
            _save_parameters(args, None, 0, config, True)
        result = True

    if single_or_rankzero() and not restart:
        if result:
            logger.log(99, 'Training Completed.')
            callback.update_status('finished')
        else:
            logger.log(99, 'Training Incompleted.')
            callback.update_status('failed')
    if single_or_rankzero():
        progress(None)
    return True
Exemplo n.º 3
0
def test_resume_suspend_equivalence(nntxt_idx, parameter_format,
                                    dataset_sample_num, batch_size):
    """Check that suspending and resuming training matches a straight run.

    A reference run ("ref" impl) trains from iteration 0 to ``a_few_iter``
    and saves the training state when ``config.iter == half_iter``.  A second
    run ("new" impl) reloads the network without parameters, restores the
    saved state with ``load_train_state`` and trains from ``half_iter`` to
    ``a_few_iter``.  Both ``info`` objects are then compared with
    ``compare_info``, and the collected (cost, error) pairs are compared with
    ``assert_allclose``.  These cases are intended to check equivalence
    before and after refactoring.
    """
    verbose = True  # print every compared cost/error pair
    a_few_iter = 10  # end iteration of both runs
    half_iter = 5  # iteration at which the state is saved / resumed
    output_network_topology = False  # set True to dump optimizer networks
    with generate_case_from_nntxt_str(NNTXT_EQUIVALENCE_CASES[nntxt_idx],
                                      parameter_format, dataset_sample_num,
                                      batch_size) as nnp_file:
        with create_temp_with_dir(
                "saved_parameter.nnp") as saved_parameter_nnp:

            # Plain attribute holder for the forward/backward hooks below.
            class Callback:
                pass

            # on_iter hook: saves the training state once, at half_iter.
            class ModelSaver:
                def __init__(self, info):
                    # NOTE(review): stored but never used by __call__.
                    self.info = info

                def __call__(self, config):
                    if config.iter != half_iter:
                        return
                    _save_parameters(saved_parameter_nnp, config,
                                     NNTXT_EQUIVALENCE_CASES[nntxt_idx])

            # Configuration for the reference run: iterations [0, a_few_iter).
            new_config = TrainConfig()
            new_config.start_iteration = 0
            new_config.end_iteration = a_few_iter
            new_config.save_optimizer_variable = False
            new_config.save_evaluation_variable = False
            new_cb = Callback()
            new_cb.forward = lambda x: x.target.forward(clear_no_need_grad=True
                                                        )
            new_cb.backward = lambda x, b: x.target.backward()
            new_config.cb = new_cb
            new_config.impl = "ref"

            ref_result = []
            ref_info = load.load(nnp_file, batch_size=batch_size)
            print("load.load")

            if output_network_topology:
                for n, opt in ref_info.optimizers.items():
                    print(n)
                    opt.network.execute_on_proto(Verifier())

            # Reference run; ModelSaver writes saved_parameter_nnp mid-way.
            new_config.on_iter = ModelSaver(ref_info)
            for cost, error in partial(train, config=new_config)(ref_info):
                ref_result.append((cost, error))

            # Resumed run: same config object, restarted at half_iter with
            # no saving hook and the "new" implementation.
            new_config.on_iter = None
            new_config.start_iteration = half_iter
            new_config.end_iteration = a_few_iter
            new_config.impl = "new"
            result = []
            # Presumably drops the parameters registered by the reference run
            # so only the restored state is used — confirm nn semantics.
            nn.clear_parameters()
            info = load.load(nnp_file,
                             batch_size=batch_size,
                             exclude_parameter=True)
            print("load.load")

            # `info` is a freshly loaded object (distinct from `ref_info`), but
            # its optimizers are defined identically, so the saved state fits.
            load_train_state(saved_parameter_nnp, info)

            for cost, error in partial(train, config=new_config)(info):
                result.append((cost, error))

            compare_info(ref_info, info)

            # NOTE(review): zip() pairs ref_result[i] (reference iteration i)
            # with result[i]. If `train` yields once per iteration starting at
            # start_iteration, result[i] is iteration half_iter + i, so the
            # pairs are misaligned. In that case `result` also has only
            # a_few_iter - half_iter entries, so `i > start_iteration` is never
            # true and no assertion runs. Consider
            # zip(ref_result[half_iter:], result) — TODO confirm what `train`
            # yields.
            for i, ((cost_ref, error_ref),
                    (cost, error)) in enumerate(zip(ref_result, result)):
                if verbose:
                    print("{}: cost: {} <--> {}".format(i, cost_ref, cost))
                    print("{}: error: {} <--> {}".format(i, error_ref, error))
                if i > new_config.start_iteration:
                    assert_allclose(np.array([cost_ref, error_ref]),
                                    np.array([cost, error]),
                                    rtol=1e-2,
                                    atol=1e-5,
                                    err_msg="Error: {}".format(nntxt_idx))