Пример #1
0
def _seed_worker(worker_id, base_seed, n_workers, i_epoch):
    """Seed Python's and NumPy's RNGs inside a DataLoader worker.

    This is a module-level function (bound with ``functools.partial`` in
    ``train``) rather than a closure, so it can be pickled when DataLoader
    workers are started with the ``spawn`` multiprocessing method.

    The seed is ``base_seed + (n_workers + 1) * i_epoch + worker_id``. It
    differs per worker within an epoch and shifts from epoch to epoch.

    Args:
        worker_id: worker index, passed in by the DataLoader.
        base_seed: experiment-level seed (``config.rseed``).
        n_workers: number of DataLoader workers (``config.n_workers``).
        i_epoch: index of the current epoch.
    """
    seed = base_seed + (n_workers + 1) * i_epoch + worker_id
    random.seed(seed)
    np.random.seed(seed)


def train(config):
    """Train a ``SelectorNet``, checkpointing and logging history every epoch.

    Steps:
        1. Seed the RNGs.
        2. Create the cache and model directories, and truncate the history
           file.
        3. Build (or load) the textifier.
        4. Build one ``VisualSelectionDataset`` per entry of
           ``config.dataset_list``.
        5. Construct the model, the optimizer and the optional LR scheduler.
        6. Run ``config.n_epoch`` epochs of training followed by validation.

    Args:
        config: experiment configuration. The attributes read here are
            ``rseed``, ``cudnn_deterministic``, ``cache_path``, ``model_dir``,
            ``history_path``, ``textifier_dict``, ``textifier_use_more_than``,
            ``dataset_list`` (its ``'train'`` and ``'valid'`` entries are
            used), ``provide_image``, ``model_net_args``,
            ``train_textifier_ratio_force_unk``,
            ``train_optional_info_ratio_force_zero``, ``device``,
            ``optimizer_name``, ``optimizer_args``, ``optimizer_scheduler``,
            ``n_epoch``, ``minibatch_size``, ``n_workers`` and
            ``weight_path_template``.
    """
    # Local import: used to bind epoch-specific arguments to _seed_worker.
    from functools import partial

    # Reproducibility
    if config.rseed is not None:
        set_random_seed(config.rseed)
    torch.backends.cudnn.deterministic = config.cudnn_deterministic

    # Initialization on Environment.
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(config.cache_path, exist_ok=True)
    os.makedirs(config.model_dir, exist_ok=True)
    print('model_dir', config.model_dir)

    # Start with an empty history file; one line per epoch is appended below.
    with open(config.history_path, 'w') as f:
        f.write('')

    # Tokenizer
    print('#', 'initializing textifiers', '...')
    textifier = UnifiedTextifier.load_or_make(
        config.textifier_dict,
        config.dataset_list['train'],
        config.textifier_use_more_than,
    )
    print('textifier vocab size', textifier.get_len())

    # Dataset & Data Loader
    print('#', 'loading datasets', '...')
    net_args = config.model_net_args
    datasets = {}
    for key, dataset_list in config.dataset_list.items():
        # The ratio_force_* settings are applied only to the 'train' split;
        # every other split gets 0.
        is_train = key == 'train'
        datasets[key] = dataset = VisualSelectionDataset(
            dataset_list=dataset_list,
            textifier=textifier,
            provide_image=config.provide_image,
            image_shape=net_args['image_shape'],
            dim_object_optional_info=net_args['dim_object_optional_info'],
            ratio_force_unk=(config.train_textifier_ratio_force_unk
                             if is_train else 0),
            ratio_force_zero=(config.train_optional_info_ratio_force_zero
                              if is_train else 0),
        )
        print('len dataset', key, len(dataset))

    # Model
    # Some modifications will be required to use multi GPU
    print('#', 'constructing a model', '...')
    model_net = SelectorNet(textifier.get_len(), **config.model_net_args)
    model_net.to(config.device)
    model_net.device = config.device

    # Optimizer: only parameters with requires_grad=True are optimized.
    optimizer_cls = getattr(torch.optim, config.optimizer_name)
    trainable_params = [p for p in model_net.parameters() if p.requires_grad]
    opt = optimizer_cls(trainable_params, **config.optimizer_args)

    # Optional LR scheduler. The 'name' key selects the class in
    # torch.optim.lr_scheduler; the remaining keys are passed as kwargs.
    scheduler = None
    if config.optimizer_scheduler is not None:
        s_args = config.optimizer_scheduler.copy()  # don't mutate the config
        name = s_args.pop('name')
        scheduler = getattr(torch.optim.lr_scheduler, name)(opt, **s_args)

    # Training loop
    print('#', 'training loop starts')
    print('n_epoch', config.n_epoch)
    for i_epoch in range(config.n_epoch):
        # Per-(epoch, worker) seeding of the worker RNGs.
        # A partial of a module-level function is picklable; a nested
        # closure is not, which breaks the 'spawn' start method.
        if config.rseed is None:
            worker_init_fn = None
        else:
            worker_init_fn = partial(
                _seed_worker,
                base_seed=config.rseed,
                n_workers=config.n_workers,
                i_epoch=i_epoch,
            )

        # training
        data_loader = DataLoader(
            datasets['train'],
            batch_size=config.minibatch_size,
            num_workers=config.n_workers,
            collate_fn=collate_fn,
            shuffle=True,
            worker_init_fn=worker_init_fn,
        )
        model_net.train()
        train_summary = run_dataset(config, model_net, opt, data_loader,
                                    i_epoch)
        if scheduler is not None:
            scheduler.step()

        # validation
        data_loader = DataLoader(
            datasets['valid'],
            batch_size=config.minibatch_size,
            num_workers=config.n_workers,
            collate_fn=collate_fn,
            shuffle=False,
        )
        model_net.eval()
        with torch.no_grad():
            valid_summary = run_dataset(config, model_net, opt, data_loader,
                                        i_epoch)

        # Save states (one checkpoint file per epoch)
        weight_path = config.weight_path_template % (i_epoch)
        checkpoint = {
            'model_net_state_dict': model_net.state_dict(),
        }
        torch.save(checkpoint, weight_path)

        # Save history: epoch index, then the train and valid 'loss'/'acc'
        # fields.
        with open(config.history_path, 'a') as f:
            f.write(' '.join([
                str(i_epoch),
                train_summary.to_str('loss', 'acc', no_name=True),
                valid_summary.to_str('loss', 'acc', no_name=True),
            ]) + '\n')

        print(' '.join([
            'ep=%d' % (i_epoch),
            train_summary.to_str(prefix='t_'),
            valid_summary.to_str(prefix='v_'),
        ]))
Пример #2
0
def eval_model(net_module):
    """Evaluate a trained model on every dataset listed in ``TARGET_DATASETS``.

    Steps:
        1. Build the config via ``net_module.TrainConfig()``.
        2. Rebuild the textifier and the model from that config.
        3. Load the weights at the path returned by
           ``_get_best_weight_path(config)``.
        4. Run ``net_module.run_dataset`` on each ``(name, spec)`` pair of
           ``TARGET_DATASETS``, without an optimizer and under
           ``torch.no_grad()``.

    One summary line is printed per dataset, then all lines are printed
    again together at the end.

    Args:
        net_module: module providing ``TrainConfig``, ``SelectorNet`` and
            ``run_dataset``.
    """
    print('#', 'Evaluation')

    config = net_module.TrainConfig()
    weight_path = _get_best_weight_path(config)
    print('model_dir', config.model_dir)
    print('best_weight_path', weight_path)

    # Tokenizer
    print('#', 'initializing textifiers')
    textifier = UnifiedTextifier.load_or_make(
        config.textifier_dict,
        config.dataset_list['train'],
        config.textifier_use_more_than,
    )
    print('textifier vocab size', len(textifier))

    # Dataset: one dataset per target spec, with both ratio_force_* set to 0.
    print('#', 'loading dataset definitions')
    net_args = config.model_net_args
    datasets = []
    for key, dataset_spec in TARGET_DATASETS:
        dataset = VisualSelectionDataset(
            dataset_list=[dataset_spec],
            textifier=textifier,
            provide_image=config.provide_image,
            image_size=net_args['image_size'],
            dim_object_optional_info=net_args['dim_object_optional_info'],
            ratio_force_unk=0,
            ratio_force_zero=0,
        )
        datasets.append((key, dataset))
        print('len', key, len(dataset))

    # Model
    print('#', 'constructing a model')
    model_net = net_module.SelectorNet(len(textifier), **config.model_net_args)
    model_net.to(config.device)
    model_net.device = config.device
    # Load the checkpoint onto the CPU; load_state_dict then copies the
    # tensors into the model's existing parameters.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    states = torch.load(weight_path, map_location=torch.device('cpu'))
    model_net.load_state_dict(states['model_net_state_dict'])

    # start evaluation
    summaries = []
    for name, dataset in datasets:
        print('+', name, 'evaluating...')

        data_loader = DataLoader(
            dataset,
            batch_size=config.minibatch_size,
            num_workers=config.n_workers,
            collate_fn=collate_fn,
            shuffle=False,
        )

        model_net.eval()
        with torch.no_grad():
            # No optimizer is passed: evaluation only.
            summary = net_module.run_dataset(config, model_net, None,
                                             data_loader, 0)

        summaries.append(summary.to_str(prefix=name + '_'))
        print(summaries[-1])

    print('#', 'results')
    print('\n'.join(summaries))