import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

# Project-local helpers assumed to be defined/imported elsewhere in this repo:
# set_random_seed, UnifiedTextifier, VisualSelectionDataset, SelectorNet,
# collate_fn, run_dataset, TARGET_DATASETS, _get_best_weight_path.


def train(config):
    # Reproducibility
    if config.rseed is not None:
        set_random_seed(config.rseed)
        torch.backends.cudnn.deterministic = config.cudnn_deterministic

    # Environment initialization
    if not os.path.exists(config.cache_path):
        os.makedirs(config.cache_path)
    if not os.path.exists(config.model_dir):
        os.makedirs(config.model_dir)
    print('model_dir', config.model_dir)
    # Truncate the history file
    with open(config.history_path, 'w') as f:
        f.write('')

    # Tokenizer
    print('#', 'initializing textifiers', '...')
    textifier = UnifiedTextifier.load_or_make(
        config.textifier_dict,
        config.dataset_list['train'],
        config.textifier_use_more_than,
    )
    print('textifier vocab size', textifier.get_len())

    # Dataset & data loader
    print('#', 'loading datasets', '...')
    datasets = {}
    for key, dataset_list in config.dataset_list.items():
        is_train = key == 'train'
        datasets[key] = dataset = VisualSelectionDataset(
            dataset_list=dataset_list,
            textifier=textifier,
            provide_image=config.provide_image,
            image_shape=config.model_net_args['image_shape'],
            dim_object_optional_info=config.model_net_args['dim_object_optional_info'],
            ratio_force_unk=config.train_textifier_ratio_force_unk if is_train else 0,
            ratio_force_zero=config.train_optional_info_ratio_force_zero if is_train else 0,
        )
        print('len dataset', key, len(dataset))

    # Model
    # Some modifications would be required to use multiple GPUs.
    print('#', 'constructing a model', '...')
    model_net = SelectorNet(textifier.get_len(), **config.model_net_args)
    model_net.to(config.device)
    model_net.device = config.device

    # Optimizer (and optional LR scheduler), both looked up by name
    opt = getattr(torch.optim, config.optimizer_name)(
        filter(lambda x: x.requires_grad, model_net.parameters()),
        **config.optimizer_args)
    scheduler = None
    if config.optimizer_scheduler is not None:
        s_args = config.optimizer_scheduler.copy()
        name = s_args.pop('name')
        scheduler = getattr(torch.optim.lr_scheduler, name)(opt, **s_args)

    # Training loop
    print('#', 'training loop starts')
    print('n_epoch', config.n_epoch)
    for i_epoch in range(config.n_epoch):

        # Seed each data-loader worker's RNGs reproducibly for this epoch
        def worker_init_fn(_id):
            if config.rseed is not None:
                seed = config.rseed + (config.n_workers + 1) * i_epoch + _id
                random.seed(seed)
                np.random.seed(seed)

        # Training
        data_loader = DataLoader(
            datasets['train'],
            batch_size=config.minibatch_size,
            num_workers=config.n_workers,
            collate_fn=collate_fn,
            shuffle=True,
            worker_init_fn=worker_init_fn,
        )
        model_net.train()
        train_summary = run_dataset(config, model_net, opt, data_loader, i_epoch)
        if scheduler is not None:
            scheduler.step()

        # Validation
        data_loader = DataLoader(
            datasets['valid'],
            batch_size=config.minibatch_size,
            num_workers=config.n_workers,
            collate_fn=collate_fn,
            shuffle=False,
        )
        model_net.eval()
        with torch.no_grad():
            valid_summary = run_dataset(config, model_net, opt, data_loader, i_epoch)

        # Save states
        weight_path = config.weight_path_template % i_epoch
        checkpoint = {
            'model_net_state_dict': model_net.state_dict(),
        }
        torch.save(checkpoint, weight_path)

        # Save history
        with open(config.history_path, 'a') as f:
            f.write(' '.join([
                str(i_epoch),
                train_summary.to_str('loss', 'acc', no_name=True),
                valid_summary.to_str('loss', 'acc', no_name=True),
            ]) + '\n')
        print(' '.join([
            'ep=%d' % i_epoch,
            train_summary.to_str(prefix='t_'),
            valid_summary.to_str(prefix='v_'),
        ]))
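
# Standalone sketch of the name-based optimizer/scheduler construction used in
# train() above: classes are looked up on torch.optim / torch.optim.lr_scheduler
# by string, so the config can stay plain data. The helper name and the example
# values below are illustrative, not part of this repo's API.
def build_optimizer(params, optimizer_name, optimizer_args, scheduler_spec=None):
    opt = getattr(torch.optim, optimizer_name)(params, **optimizer_args)
    scheduler = None
    if scheduler_spec is not None:
        s_args = scheduler_spec.copy()
        name = s_args.pop('name')  # e.g. 'StepLR'
        scheduler = getattr(torch.optim.lr_scheduler, name)(opt, **s_args)
    return opt, scheduler

# Example:
#   net = torch.nn.Linear(4, 2)
#   opt, sched = build_optimizer(net.parameters(), 'Adam', {'lr': 1e-3},
#                                {'name': 'StepLR', 'step_size': 10, 'gamma': 0.5})
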
def eval_model(net_module):
    print('#', 'Evaluation')
    config = net_module.TrainConfig()
    weight_path = _get_best_weight_path(config)
    print('model_dir', config.model_dir)
    print('best_weight_path', weight_path)

    # Tokenizer
    print('#', 'initializing textifiers')
    textifier = UnifiedTextifier.load_or_make(
        config.textifier_dict,
        config.dataset_list['train'],
        config.textifier_use_more_than,
    )
    print('textifier vocab size', len(textifier))

    # Dataset
    print('#', 'loading dataset definitions')
    datasets = []
    for key, dataset_spec in TARGET_DATASETS:
        dataset = VisualSelectionDataset(
            dataset_list=[dataset_spec],
            textifier=textifier,
            provide_image=config.provide_image,
            image_size=config.model_net_args['image_size'],
            dim_object_optional_info=config.model_net_args['dim_object_optional_info'],
            ratio_force_unk=0,
            ratio_force_zero=0,
        )
        datasets.append((key, dataset))
        print('len', key, len(dataset))

    # Model
    print('#', 'constructing a model')
    model_net = net_module.SelectorNet(len(textifier), **config.model_net_args)
    model_net.to(config.device)
    model_net.device = config.device

    # Load the best checkpoint
    states = torch.load(weight_path, map_location=torch.device('cpu'))
    model_net.load_state_dict(states['model_net_state_dict'])

    # Start evaluation
    summaries = []
    for name, dataset in datasets:
        print('+', name, 'evaluating...')
        data_loader = DataLoader(
            dataset,
            batch_size=config.minibatch_size,
            num_workers=config.n_workers,
            collate_fn=collate_fn,
            shuffle=False,
        )
        model_net.eval()
        with torch.no_grad():
            summary = net_module.run_dataset(config, model_net, None, data_loader, 0)
        summaries.append(summary.to_str(prefix=name + '_'))
        print(summaries[-1])

    print('#', 'results')
    print('\n'.join(summaries))
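
# Hypothetical entry point (sketch): eval_model only requires a module that
# exposes TrainConfig, SelectorNet, and run_dataset. The module path below is
# an assumed example, not this repo's actual layout.
if __name__ == '__main__':
    import importlib
    net_module = importlib.import_module('models.baseline')  # hypothetical path
    eval_model(net_module)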