Example #1
def _train_impl(args, model_specs, logger):
    if len(args.output) > 0:
        _make_dirs(args.output)
    # dataiter
    dataset_specs = get_dataset_specs(args, model_specs)
    scale, mean_, _ = _get_scalemeanstd()
    if scale > 0:
        mean_ /= scale
    margs = argparse.Namespace(**model_specs)
    dargs = argparse.Namespace(**dataset_specs)
    dataiter = FileIter(
        dataset=margs.dataset,
        split=args.split,
        data_root=args.data_root,
        sampler='random',
        batch_images=args.batch_images,
        meta=dataset_specs,
        rgb_mean=mean_,
        feat_stride=margs.feat_stride,
        label_stride=margs.feat_stride,
        origin_size=args.origin_size,
        crop_size=args.crop_size,
        scale_rate_range=[float(_) for _ in args.scale_rate_range.split(',')],
        transformer=None,
        transformer_image=ts.Compose(_get_transformer_image()),
        prefetch_threads=args.prefetch_threads,
        prefetcher_type=args.prefetcher,
    )
    dataiter.reset()
    # optimizer
    assert args.to_epoch is not None
    if args.stop_epoch is not None:
        assert args.from_epoch < args.stop_epoch <= args.to_epoch
    else:
        args.stop_epoch = args.to_epoch
    from_iter = args.from_epoch * dataiter.batches_per_epoch
    to_iter = args.to_epoch * dataiter.batches_per_epoch
    lr_params = model_specs['lr_params']
    base_lr = lr_params['base']
    if lr_params['type'] == 'fixed':
        scheduler = FixedScheduler()
    elif lr_params['type'] == 'step':
        left_step = []
        for step in lr_params['args']['step']:
            if from_iter > step:
                base_lr *= lr_params['args']['factor']
                continue
            left_step.append(step - from_iter)
        lr_params['args']['step'] = left_step  # rebase remaining steps relative to from_iter
        scheduler = mx.lr_scheduler.MultiFactorScheduler(**lr_params['args'])
    elif lr_params['type'] == 'linear':
        scheduler = LinearScheduler(updates=to_iter + 1,
                                    frequency=50,
                                    stop_lr=min(base_lr / 100., 1e-6),
                                    offset=from_iter)
    else:
        raise ValueError('Unknown lr scheduler type: {}'.format(lr_params['type']))
    optimizer_params = {
        'learning_rate': base_lr,
        'momentum': 0.9,
        'wd': args.weight_decay,
        'lr_scheduler': scheduler,
        'rescale_grad': 1.0 / len(args.gpus.split(',')),
    }
    # initializer
    net_args = None
    net_auxs = None
    if args.weights is not None:
        net_args, net_auxs = mxutil.load_params_from_file(args.weights)
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type='in',
                                 magnitude=2)
    #
    to_model = osp.join(args.output, '{}_ep'.format(args.model))
    mod = _get_module(args, margs, dargs)
    mod.fit(
        dataiter,
        eval_metric=_get_metric(),
        batch_end_callback=mx.callback.Speedometer(dataiter.batch_size, 1),
        epoch_end_callback=mx.callback.do_checkpoint(to_model),
        kvstore=args.kvstore,
        optimizer='sgd',
        optimizer_params=optimizer_params,
        initializer=initializer,
        arg_params=net_args,
        aux_params=net_auxs,
        allow_missing=args.from_epoch == 0,
        begin_epoch=args.from_epoch,
        num_epoch=args.stop_epoch,
    )
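
The trickiest part of the example above is how it resumes a 'step' schedule mid-training: the decay factors for steps already passed are folded into the base learning rate, and the remaining steps are rebased so that iteration counting restarts at the resume point. A minimal sketch of that logic in isolation (the rewind_steps name and its arguments are illustrative, not from the original code):

def rewind_steps(steps, factor, base_lr, from_iter):
    """Fold already-passed decay steps into base_lr; rebase the rest."""
    left = []
    for s in steps:
        if from_iter > s:
            base_lr *= factor  # this decay happened before the resume point
        else:
            left.append(s - from_iter)
    return left, base_lr

# Resuming at iteration 15000 of a schedule that decays by 0.1 at 10k/20k/30k:
steps, lr = rewind_steps([10000, 20000, 30000], 0.1, 0.01, 15000)
# steps == [5000, 15000], lr == 0.001
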
Example #2
def _train_impl(args, model_specs, logger):
    if len(args.output) > 0:
        _make_dirs(args.output)
    # dataiter
    dataset_specs_tgt = get_dataset_specs_tgt(args, model_specs)
    scale, mean_, _ = _get_scalemeanstd()
    if scale > 0:
        mean_ /= scale
    margs = argparse.Namespace(**model_specs)
    dargs = argparse.Namespace(**dataset_specs_tgt)
    # count the lines in the source split list
    split_filename = 'issegm/data_list/{}/{}.lst'.format(
        margs.dataset, args.split)
    with open(split_filename) as f:
        num_source = sum(1 for _ in f)
    #
    batches_per_epoch = num_source // args.batch_images
    # optimizer
    assert args.to_epoch is not None
    if args.stop_epoch is not None:
        assert args.from_epoch < args.stop_epoch <= args.to_epoch
    else:
        args.stop_epoch = args.to_epoch

    from_iter = args.from_epoch * batches_per_epoch
    to_iter = args.to_epoch * batches_per_epoch
    lr_params = model_specs['lr_params']
    base_lr = lr_params['base']
    if lr_params['type'] == 'fixed':
        scheduler = FixedScheduler()
    elif lr_params['type'] == 'step':
        left_step = []
        for step in lr_params['args']['step']:
            if from_iter > step:
                base_lr *= lr_params['args']['factor']
                continue
            left_step.append(step - from_iter)
        lr_params['args']['step'] = left_step  # rebase remaining steps relative to from_iter
        scheduler = mx.lr_scheduler.MultiFactorScheduler(**lr_params['args'])
    elif lr_params['type'] == 'linear':
        scheduler = LinearScheduler(updates=to_iter + 1,
                                    frequency=50,
                                    stop_lr=min(base_lr / 100., 1e-6),
                                    offset=from_iter)
    elif lr_params['type'] == 'poly':
        scheduler = PolyScheduler(updates=to_iter + 1,
                                  frequency=50,
                                  stop_lr=min(base_lr / 100., 1e-8),
                                  power=0.9,
                                  offset=from_iter)
    else:
        raise ValueError('Unknown lr scheduler type: {}'.format(lr_params['type']))

    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type='in',
                                 magnitude=2)
    optimizer_params = {
        'learning_rate': base_lr,
        'momentum': 0.9,
        'wd': args.weight_decay,
        'lr_scheduler': scheduler,
        'rescale_grad': 1.0 / len(args.gpus.split(',')),
    }

    data_src_port = args.init_src_port  # initial portion (fraction) of source samples to select
    data_src_num = int(num_source * data_src_port)
    mod = _get_module(args, margs, dargs)
    addr_weights = args.weights  # first weights should be xxxx_ep-0000.params!
    addr_output = args.output

    # initializer
    net_args = None
    net_auxs = None
    ###
    if addr_weights is not None:
        net_args, net_auxs = mxutil.load_params_from_file(addr_weights)

    ####################################### training model
    to_model = osp.join(addr_output, str(args.idx_round),
                        '{}_ep'.format(args.model))
    dataiter = FileIter(
        dataset=margs.dataset,
        split=args.split,
        data_root=args.data_root,
        num_sel_source=data_src_num,
        num_source=num_source,
        seed_int=args.seed_int,
        dataset_tgt=args.dataset_tgt,
        split_tgt=args.split_tgt,
        data_root_tgt=args.data_root_tgt,
        sampler='random',
        batch_images=args.batch_images,
        meta=dataset_specs_tgt,
        rgb_mean=mean_,
        feat_stride=margs.feat_stride,
        label_stride=margs.feat_stride,
        origin_size=args.origin_size,
        origin_size_tgt=args.origin_size_tgt,
        crop_size=args.crop_size,
        scale_rate_range=[float(_) for _ in args.scale_rate_range.split(',')],
        transformer=None,
        transformer_image=ts.Compose(_get_transformer_image()),
        prefetch_threads=args.prefetch_threads,
        prefetcher_type=args.prefetcher,
    )
    dataiter.reset()
    mod.fit(
        dataiter,
        eval_metric=_get_metric(),
        batch_end_callback=mx.callback.log_train_metric(10, auto_reset=False),
        epoch_end_callback=mx.callback.do_checkpoint(to_model),
        kvstore=args.kvstore,
        optimizer='sgd',
        optimizer_params=optimizer_params,
        initializer=initializer,
        arg_params=net_args,
        aux_params=net_auxs,
        allow_missing=args.from_epoch == 0,
        begin_epoch=args.from_epoch,
        num_epoch=args.stop_epoch,
    )
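
PolyScheduler here is a project-specific class rather than a stock MXNet scheduler. The polynomial decay it presumably implements, given the power and stop_lr arguments above, looks like this (a sketch under that assumption, with illustrative names):

def poly_lr(base_lr, t, max_iter, power=0.9, stop_lr=1e-8):
    """Polynomial decay: lr(t) = base_lr * (1 - t / max_iter) ** power."""
    lr = base_lr * (1.0 - float(t) / max_iter) ** power
    return max(lr, stop_lr)

# e.g. halfway through 10000 updates with base_lr = 0.01:
print(poly_lr(0.01, 5000, 10000))  # ~0.00536
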
Example #3
def _val_impl(args, model_specs, logger):
    assert args.prefetch_threads == 1
    assert args.weights is not None
    net_args, net_auxs = util.load_params_from_file(args.weights)
    mod = _get_module(model_specs)
    has_gt = args.split in (
        'train',
        'val',
    )
    scale_, mean_, std_ = _get_scalemeanstd()
    if args.test_scales is None:
        crop_sizes = [model_specs['crop_size']]
    else:
        crop_sizes = sorted([int(_)
                             for _ in args.test_scales.split(',')])[::-1]

    batch_images = args.batch_images

    if has_gt:
        gt_labels = np.array(
            parse_split_file(model_specs['split_filename'], args.data_root)[1])
    save_dir = os.path.join(args.output, os.path.splitext(args.log_file)[0])
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    preds = []
    for crop_size in crop_sizes:
        save_path = os.path.join(save_dir, 'preds_sz{}'.format(crop_size))
        if os.path.isfile(save_path):
            logger.info('File %s exists, skipped crop size %d', save_path,
                        crop_size)
            with open(save_path, 'rb') as f:
                preds.append(cPickle.load(f))
            continue
        ts_list = [
            ts.Scale(crop_size),
            ts.ThreeCrops(crop_size)
            if args.test_3crops else ts.CenterCrop(crop_size),
        ]
        if scale_ > 0:
            ts_list.append(ts.ListInput(ts.ColorScale(np.single(scale_))))
        ts_list += [ts.ListInput(ts.ColorNormalize(mean_, std_))]
        transformer = ts.Compose(ts_list)
        dataiter = FileIter(
            split_filename=model_specs['split_filename'],
            data_root=args.data_root,
            has_gt=has_gt,
            batch_images=batch_images,
            transformer=transformer,
            prefetch_threads=args.prefetch_threads,
            prefetcher_type=args.prefetcher,
        )
        dataiter.reset()
        mod.bind(dataiter.provide_data,
                 dataiter.provide_label,
                 for_training=False,
                 force_rebind=True)
        if not mod.params_initialized:
            mod.init_params(arg_params=net_args, aux_params=net_auxs)
        this_call_preds = []
        start = time.time()
        counter = [0, 0]
        for nbatch, batch in enumerate(dataiter):
            mod.forward(batch, is_train=False)
            outputs = mod.get_outputs()[0].asnumpy()
            outputs = outputs.reshape(
                (batch_images, -1, model_specs['classes'])).mean(1)
            this_call_preds.append(outputs)
            if args.test_flipping:
                batch.data[0] = mx.nd.flip(batch.data[0], axis=3)
                mod.forward(batch, is_train=False)
                outputs = mod.get_outputs()[0].asnumpy()
                outputs = outputs.reshape(
                    (batch_images, -1, model_specs['classes'])).mean(1)
                this_call_preds[-1] = (this_call_preds[-1] + outputs) / 2
            score_str = ''
            if has_gt:
                counter[0] += batch_images
                counter[1] += (this_call_preds[-1].argmax(1) ==
                               gt_labels[nbatch * batch_images:(nbatch + 1) *
                                         batch_images]).sum()
                score_str = ', Top1 {:.4f}%'.format(100.0 * counter[1] /
                                                    counter[0])
            logger.info('Crop size {}, done {}/{} at speed: {:.2f}/s{}'.format(
                crop_size, nbatch + 1, dataiter.batches_per_epoch,
                1. * (nbatch + 1) * batch_images / (time.time() - start),
                score_str))
        logger.info('Done crop size {} in {:.4f}s.'.format(
            crop_size,
            time.time() - start))
        this_call_preds = np.vstack(this_call_preds)
        with open(save_path, 'wb') as f:
            cPickle.dump(this_call_preds, f)
        preds.append(this_call_preds)
    for num_sizes in set((
            1,
            len(crop_sizes),
    )):
        for this_pred_inds in itertools.combinations(xrange(len(crop_sizes)),
                                                     num_sizes):
            this_pred = np.mean([preds[_] for _ in this_pred_inds], axis=0)
            this_pred_label = this_pred.argsort(1)[:, -1 - np.arange(5)]
            logger.info('Done testing crop_size %s',
                        [crop_sizes[_] for _ in this_pred_inds])
            if has_gt:
                top1 = 100. * (this_pred_label[:, 0]
                               == gt_labels).sum() / gt_labels.size
                top5 = 100. * sum(
                    map(lambda x, y: y in x.tolist(), this_pred_label,
                        gt_labels)) / gt_labels.size
                logger.info('Top1 %.4f%%, Top5 %.4f%%', top1, top5)
            else:
                # TODO: Save predictions for submission
                raise NotImplementedError('Save predictions for submission')
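
The running Top1 counter and the final Top1/Top5 report above implement standard top-k accuracy. The same computation in isolation (a minimal NumPy sketch; names are illustrative):

import numpy as np

def topk_accuracy(scores, labels, k=5):
    """scores: (N, C) class scores; labels: (N,) ground-truth indices."""
    topk = scores.argsort(axis=1)[:, ::-1][:, :k]  # best class first
    top1 = 100.0 * np.mean(topk[:, 0] == labels)
    topk_hit = 100.0 * np.mean([y in row for row, y in zip(topk, labels)])
    return top1, topk_hit

scores = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2]])
labels = np.array([1, 2])
print(topk_accuracy(scores, labels, k=2))  # (50.0, 50.0)
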
Example #4
def _train_impl(args, model_specs, logger):
    if len(args.output) > 0:
        _make_dirs(args.output)
    # dataiter
    dataset_specs = get_dataset_specs(args, model_specs)
    scale, mean_, _ = _get_scalemeanstd()
    if scale > 0:
        mean_ /= scale
    margs = argparse.Namespace(**model_specs)
    dargs = argparse.Namespace(**dataset_specs)
    dataiter = FileIter(dataset=margs.dataset,
                        split=args.split,
                        data_root=args.data_root,
                        sampler='random',
                        batch_images=args.batch_images,
                        meta=dataset_specs,
                        rgb_mean=mean_,
                        feat_stride=margs.feat_stride,
                        label_stride=margs.feat_stride,
                        origin_size=args.origin_size,
                        crop_size=args.crop_size,
                        scale_rate_range=[float(_) for _ in args.scale_rate_range.split(',')],
                        transformer=None,
                        transformer_image=ts.Compose(_get_transformer_image()),
                        prefetch_threads=args.prefetch_threads,
                        prefetcher_type=args.prefetcher,)
    dataiter.reset()
    # optimizer
    assert args.to_epoch is not None
    if args.stop_epoch is not None:
        assert args.from_epoch < args.stop_epoch <= args.to_epoch
    else:
        args.stop_epoch = args.to_epoch
    from_iter = args.from_epoch * dataiter.batches_per_epoch
    to_iter = args.to_epoch * dataiter.batches_per_epoch
    lr_params = model_specs['lr_params']
    base_lr = lr_params['base']
    if lr_params['type'] == 'fixed':
        scheduler = FixedScheduler()
    elif lr_params['type'] == 'step':
        left_step = []
        for step in lr_params['args']['step']:
            if from_iter > step:
                base_lr *= lr_params['args']['factor']
                continue
            left_step.append(step - from_iter)
        lr_params['args']['step'] = left_step  # rebase remaining steps relative to from_iter
        scheduler = mx.lr_scheduler.MultiFactorScheduler(**lr_params['args'])
    elif lr_params['type'] == 'linear':
        scheduler = LinearScheduler(updates=to_iter+1, frequency=50,
                                    stop_lr=min(base_lr/100., 1e-6),
                                    offset=from_iter)
    else:
        raise ValueError('Unknown lr scheduler type: {}'.format(lr_params['type']))
    optimizer_params = {
        'learning_rate': base_lr,
        'momentum': 0.9,
        'wd': args.weight_decay,
        'lr_scheduler': scheduler,
        'rescale_grad': 1.0/len(args.gpus.split(',')),
    }
    # initializer
    net_args = None
    net_auxs = None
    if args.weights is not None:
        net_args, net_auxs = mxutil.load_params_from_file(args.weights)
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2)
    #
    to_model = osp.join(args.output, '{}_ep'.format(args.model))
    mod = _get_module(args, margs, dargs)
    mod.fit(
        dataiter,
        eval_metric=_get_metric(),
        batch_end_callback=mx.callback.Speedometer(dataiter.batch_size, 1),
        epoch_end_callback=mx.callback.do_checkpoint(to_model),
        kvstore=args.kvstore,
        optimizer='sgd',
        optimizer_params=optimizer_params,
        initializer=initializer,
        arg_params=net_args,
        aux_params=net_auxs,
        allow_missing=args.from_epoch == 0,
        begin_epoch=args.from_epoch,
        num_epoch=args.stop_epoch,
    )
Example #5
    ctx = mx.gpu(1)

    # =============module network=============
    mod_network = mx.mod.Module(symbol=network, data_names=('data',), label_names=('l2_label',), context=ctx)
    mod_network.bind(data_shapes=dataiter.provide_data, label_shapes=dataiter.provide_label)
    mod_network.init_params(initializer=mx.init.Normal(0.02))
    mod_network.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
        })

    # create eval_metric
    eval_metric = mx.metric.create('rmse')

    data_name = dataiter.data_name
    label_name = dataiter.label_name

    for epoch in range(10000):
        dataiter.reset()
        eval_metric.reset()  # clear the running metric so each epoch reports fresh RMSE
        for t, batch in enumerate(dataiter):
            mod_network.forward(batch, is_train=True)
            mod_network.backward()
            mod_network.update()
            mod_network.update_metric(eval_metric, batch.label)

            print('epoch:', epoch, 'iter:', t, 'metric:', eval_metric.get())

        mod_network.save_params('model_%04d.params' % epoch)
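
save_params stores only the weights. If a long run like this one may need to be resumed, checkpointing the symbol and optimizer state as well is safer; a sketch using the Module API (the 'model' prefix is illustrative):

# save symbol + weights (+ optimizer state) at the end of an epoch
mod_network.save_checkpoint('model', epoch, save_optimizer_states=True)

# later: reload the weights saved at epoch 42 into a bound module
sym, arg_params, aux_params = mx.model.load_checkpoint('model', 42)
mod_network.set_params(arg_params, aux_params)
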
Example #6
def _val_impl(args, model_specs, logger):
    assert args.prefetch_threads == 1
    assert args.weights is not None
    net_args, net_auxs = mxutil.load_params_from_file(args.weights)
    mod = _get_module(args, model_specs)
    has_gt = args.split in ('train', 'val',)
    scale_, mean_, std_ = _get_scalemeanstd()
    if args.test_scales is None:
        crop_sizes = [model_specs['crop_size']]
    else:
        crop_sizes = sorted([int(_) for _ in args.test_scales.split(',')])[::-1]
    
    batch_images = args.batch_images
    
    if has_gt:
        gt_labels = np.array(parse_split_file(model_specs['dataset'], args.split, args.data_root)[1])
    save_dir = os.path.join(args.output, os.path.splitext(args.log_file)[0])
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    preds = []
    for crop_size in crop_sizes:
        save_path = os.path.join(save_dir, 'preds_sz{}'.format(crop_size))
        if os.path.isfile(save_path):
            logger.info('File %s exists, skipped crop size %d', save_path, crop_size)
            with open(save_path, 'rb') as f:
                preds.append(cPickle.load(f))
            continue
        ts_list = [ts.Scale(crop_size),
                   ts.ThreeCrops(crop_size) if args.test_3crops else ts.CenterCrop(crop_size),]
        if scale_ > 0:
            ts_list.append(ts.ListInput(ts.ColorScale(np.single(scale_))))
        ts_list += [ts.ListInput(ts.ColorNormalize(mean_, std_))]
        transformer = ts.Compose(ts_list)
        dataiter = FileIter(dataset=model_specs['dataset'],
                            split=args.split,
                            data_root=args.data_root,
                            sampler='fixed',
                            has_gt=has_gt,
                            batch_images=batch_images,
                            transformer=transformer,
                            prefetch_threads=args.prefetch_threads,
                            prefetcher_type=args.prefetcher,)
        dataiter.reset()
        mod.bind(dataiter.provide_data, dataiter.provide_label, for_training=False, force_rebind=True)
        if not mod.params_initialized:
            mod.init_params(arg_params=net_args, aux_params=net_auxs)
        this_call_preds = []
        start = time.time()
        counter = [0, 0]
        for nbatch, batch in enumerate(dataiter):
            mod.forward(batch, is_train=False)
            outputs = mod.get_outputs()[0].asnumpy()
            outputs = outputs.reshape((batch_images, -1, model_specs['classes'])).mean(1)
            this_call_preds.append(outputs)
            if args.test_flipping:
                batch.data[0] = mx.nd.flip(batch.data[0], axis=3)
                mod.forward(batch, is_train=False)
                outputs = mod.get_outputs()[0].asnumpy()
                outputs = outputs.reshape((batch_images, -1, model_specs['classes'])).mean(1)
                this_call_preds[-1] = (this_call_preds[-1] + outputs) / 2
            score_str = ''
            if has_gt:
                counter[0] += batch_images
                counter[1] += (this_call_preds[-1].argmax(1) == gt_labels[nbatch*batch_images : (nbatch+1)*batch_images]).sum()
                score_str = ', Top1 {:.4f}%'.format(100.0*counter[1] / counter[0])
            logger.info('Crop size {}, done {}/{} at speed: {:.2f}/s{}'.format(
                crop_size, nbatch + 1, dataiter.batches_per_epoch,
                1. * (nbatch + 1) * batch_images / (time.time() - start),
                score_str))
        logger.info('Done crop size {} in {:.4f}s.'.format(crop_size, time.time() - start))
        this_call_preds = np.vstack(this_call_preds)
        with open(save_path, 'wb') as f:
            cPickle.dump(this_call_preds, f)
        preds.append(this_call_preds)
    for num_sizes in set((1, len(crop_sizes),)):
        for this_pred_inds in itertools.combinations(xrange(len(crop_sizes)), num_sizes):
            this_pred = np.mean([preds[_] for _ in this_pred_inds], axis=0)
            this_pred_label = this_pred.argsort(1)[:, -1 - np.arange(5)]
            logger.info('Done testing crop_size %s', [crop_sizes[_] for _ in this_pred_inds])
            if has_gt:
                top1 = 100. * (this_pred_label[:, 0] == gt_labels).sum() / gt_labels.size
                top5 = 100. * sum(map(lambda x, y: y in x.tolist(), this_pred_label, gt_labels)) / gt_labels.size
                logger.info('Top1 %.4f%%, Top5 %.4f%%', top1, top5)
            else:
                # TODO: Save predictions for submission
                raise NotImplementedError('Save predictions for submission')
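
The closing loop ensembles crop sizes by averaging their score matrices, once per single size and once over the full set. The same idea in isolation (a minimal sketch; names are illustrative):

import itertools
import numpy as np

def ensemble_combinations(preds_by_size):
    """preds_by_size: list of (N, C) score arrays, one per crop size.
    Yields (size indices, averaged scores) for each singleton and the full set."""
    n = len(preds_by_size)
    for k in {1, n}:  # singletons plus the all-sizes ensemble
        for inds in itertools.combinations(range(n), k):
            yield inds, np.mean([preds_by_size[i] for i in inds], axis=0)

# e.g. with two crop sizes this yields (0,), (1,), and (0, 1)
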
Example #7
def _train_impl(args, model_specs, logger):
    # dataiter
    scale_, mean_, std_ = _get_scalemeanstd()
    assert scale_ == 1./255
    pca = (np.array([0.2175, 0.0188, 0.0045]),
           np.array([[-0.5675, 0.7192, 0.4009],
                     [-0.5808, -0.0045, -0.814],
                     [-0.5836, -0.6948, 0.4203]]))
    crop_size = model_specs['crop_size']
    transformer = ts.Compose([ts.RandomSizedCrop(crop_size),
                              ts.ColorScale(np.single(1./255)),
                              ts.ColorJitter(crop_size, 0.4, 0.4, 0.4),
                              ts.Lighting(0.1, pca[0], pca[1]),
                              ts.Bound(lower=0., upper=1.),
                              ts.HorizontalFlip(),
                              ts.ColorNormalize(mean_, std_),])
    if model_specs['dataset'] == 'ilsvrc-cls':
        dataiter = FileIter(dataset=model_specs['dataset'],
                            split=args.split,
                            data_root=args.data_root,
                            sampler='random',
                            batch_images=model_specs['batch_images'],
                            transformer=transformer,
                            prefetch_threads=args.prefetch_threads,
                            prefetcher_type=args.prefetcher,)
    else:
        raise NotImplementedError('Unknown dataset: {}'.format(model_specs['dataset']))
    dataiter.reset()
    # optimizer
    assert args.to_epoch is not None
    if args.stop_epoch is not None:
        assert args.from_epoch < args.stop_epoch <= args.to_epoch
    else:
        args.stop_epoch = args.to_epoch
    from_iter = args.from_epoch * dataiter.batches_per_epoch
    to_iter = args.to_epoch * dataiter.batches_per_epoch
    lr_params = model_specs['lr_params']
    base_lr = lr_params['base']
    if lr_params['type'] == 'fixed':
        scheduler = FixedScheduler()
    elif lr_params['type'] == 'step':
        left_step = []
        for step in lr_params['args']['step']:
            if from_iter > step:
                base_lr *= lr_params['args']['factor']
                continue
            left_step.append(step - from_iter)
        lr_params['args']['step'] = left_step  # rebase remaining steps relative to from_iter
        scheduler = mx.lr_scheduler.MultiFactorScheduler(**lr_params['args'])
    elif lr_params['type'] == 'linear':
        scheduler = LinearScheduler(updates=to_iter+1, frequency=50,
                                    stop_lr=1e-6, offset=from_iter)
    else:
        raise ValueError('Unknown lr scheduler type: {}'.format(lr_params['type']))
    optimizer_params = {
        'learning_rate': base_lr,
        'momentum': 0.9,
        'wd': 0.0001,
        'lr_scheduler': scheduler,
    }
    # initializer
    net_args = None
    net_auxs = None
    if args.weights is not None:
        net_args, net_auxs = mxutil.load_params_from_file(args.weights)
    initializer = mx.init.Mixed(
        ['linear.*', '.*',],
        [TorchXavier_Linear(rnd_type='uniform', factor_type='in', magnitude=1),
         mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2),]
    )
    # fit
    to_model = os.path.join(args.output, '{}_ep'.format(args.model))
    mod = _get_module(args, model_specs)
    mod.fit(
        dataiter,
        eval_metric=_get_metric(),
        batch_end_callback=mx.callback.Speedometer(dataiter.batch_size, 1),
        epoch_end_callback=mx.callback.do_checkpoint(to_model),
        kvstore=args.kvstore,
        optimizer='TorchNesterov',
        optimizer_params=optimizer_params,
        initializer=initializer,
        arg_params=net_args,
        aux_params=net_auxs,
        allow_missing=args.from_epoch == 0,
        begin_epoch=args.from_epoch,
        num_epoch=args.stop_epoch,
    )
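
ts.Lighting(0.1, pca[0], pca[1]) is the AlexNet-style PCA color augmentation, using the ILSVRC eigenvalues and eigenvectors hard-coded above. The underlying math in plain NumPy (a sketch; pca_lighting is an illustrative name, not the project's transformer):

import numpy as np

def pca_lighting(img, alphastd, eigval, eigvec, rng=np.random):
    """Add eigvec.dot(alpha * eigval) to every pixel, alpha ~ N(0, alphastd)."""
    alpha = rng.normal(0.0, alphastd, size=3)
    return img + eigvec.dot(alpha * eigval)  # img: (..., 3) RGB in [0, 1]

eigval = np.array([0.2175, 0.0188, 0.0045])
eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                   [-0.5808, -0.0045, -0.814],
                   [-0.5836, -0.6948, 0.4203]])
jittered = pca_lighting(np.random.rand(224, 224, 3), 0.1, eigval, eigvec)
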