def get_model_and_data_loaders(
        config: ConfigParser,
        logger: logging.Logger,
        ckpt_path: Path,
) -> Tuple[torch.nn.Module, module_data.ExpertDataLoader]:
    expert_dims, raw_input_dims, text_dim = compute_dims(config)
    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        challenge_mode=config.get("challenge_mode", False),
        text_dim=text_dim,
        text_feat=config["experts"]["text_feat"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing", False),
        task=config.get("task", "retrieval"),
        eval_only=True,
        distil_params=config.get("distil_params", None),
        training_file=config.get("training_file", None),
        caption_masks=config.get("caption_masks", None),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
    )
    trn_config = compute_trn_config(config)
    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=text_dim,
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    # support backwards compatibility
    deprecated = ["ce.moe_fc_bottleneck1", "ce.moe_cg", "ce.moe_fc_proj"]
    for mod in deprecated:
        for suffix in ("weight", "bias"):
            key = f"{mod}.{suffix}"
            if key in state_dict:
                print(f"WARNING: Removing deprecated key {key} from model")
                state_dict.pop(key)
    model.load_state_dict(state_dict)
    return model, data_loaders

def get_model_and_data_loaders(
        config: ConfigParser,
        logger: logging.Logger,
        ckpt_path: Path,
) -> Tuple[torch.nn.Module, module_data.ExpertDataLoader]:
    expert_dims, raw_input_dims = compute_dims(config)
    trn_config = compute_trn_config(config)
    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        challenge_mode=config.get("challenge_mode", False),
        text_feat=config["experts"]["text_feat"],
        text_dim=config["experts"]["text_dim"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing", False),
        task=config.get("task", "retrieval"),
        eval_only=True,
    )
    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=config["experts"]["text_dim"],
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)
    return model, data_loaders

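# The deprecated-key handling in the first loader variant above can be factored into a
# small reusable helper.  This is a minimal sketch using plain PyTorch; the name
# `strip_deprecated_keys` is illustrative and not a function defined in this codebase.
from typing import Dict, Iterable

import torch


def strip_deprecated_keys(
        state_dict: Dict[str, torch.Tensor],
        deprecated_modules: Iterable[str],
) -> Dict[str, torch.Tensor]:
    """Drop weight/bias entries for modules that no longer exist in the model."""
    cleaned = dict(state_dict)
    for mod in deprecated_modules:
        for suffix in ("weight", "bias"):
            key = f"{mod}.{suffix}"
            if key in cleaned:
                print(f"WARNING: Removing deprecated key {key} from model")
                cleaned.pop(key)
    return cleaned


# Example usage (assumes `model` and `checkpoint` as loaded above):
# state_dict = strip_deprecated_keys(
#     checkpoint["state_dict"],
#     ["ce.moe_fc_bottleneck1", "ce.moe_cg", "ce.moe_fc_proj"],
# )
# model.load_state_dict(state_dict)
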
def evaluation(config, logger=None, trainer=None):
    if logger is None:
        logger = config.get_logger('test')

    if getattr(config._args, "eval_from_training_config", False):
        eval_conf = copy.deepcopy(config)
        merge(eval_conf._config, config["eval_settings"], strategy=Strategy.REPLACE)
        config = eval_conf

    logger.info("Running evaluation with configuration:")
    logger.info(config)
    expert_dims, raw_input_dims = compute_dims(config)
    trn_config = compute_trn_config(config)

    # Set the random initial seeds
    seed = config["seed"]
    logger.info(f"Setting experiment random seed to {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    update_src_web_video_dir(config)
    visualizer = config.init(
        name='visualizer',
        module=module_vis,
        exp_name=config._exper_name,
        web_dir=config._web_log_dir,
    )

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        challenge_mode=config.get("challenge_mode", False),
        text_feat=config["experts"]["text_feat"],
        text_dim=config["experts"]["text_dim"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing", False),
        task=config.get("task", "retrieval"),
        eval_only=True,
    )

    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=config["experts"]["text_dim"],
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    logger.info(model)

    metrics = [getattr(module_metric, met) for met in config['metrics']]
    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    challenge_mode = config.get("challenge_mode", False)
    challenge_msg = (
        "\n"
        "Evaluation ran on challenge features. To obtain a score, upload the similarity "
        "matrix for each dataset to the test server after running the "
        "`misc/cvpr2020-challenge/prepare_submission.py` script and following the "
        "instructions at: "
        "https://www.robots.ox.ac.uk/~vgg/challenges/video-pentathlon/"
        "\n"
    )

    # prepare model for testing. Note that some datasets fail to fit the retrieval
    # set on the GPU, so we run them on the CPU
    if torch.cuda.is_available() and not config.get("disable_gpu", True):
        device = "cuda"
    else:
        device = "cpu"
    logger.info(f"Running evaluation on {device}")

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        samples, meta = data_loaders["retrieval"]

        # To use the nan-checks safely, we need to make temporary copies of the data
        disable_nan_checks = config._config["disable_nan_checks"]
        with ctxt_mgr(samples, device, disable_nan_checks) as valid:
            output = model(**valid)

        sims = output["cross_view_conf_matrix"].data.cpu().float().numpy()
        dataset = data_loaders.dataset_name
        if challenge_mode:
            split = data_loaders.dataloaders["dataset"].split_name
            prediction_path = config._log_dir / f"{dataset}-{split}-predictions.csv"
            compressed_preds = compress_predictions(
                query_masks=meta["query_masks"],
                sims=sims,
            )
            np.savetxt(prediction_path, compressed_preds, delimiter=',', fmt="%d")
            print(f"Saved similarity matrix predictions to {prediction_path}")
            print(challenge_msg)
            return

        nested_metrics = {}
        for metric in metrics:
            metric_name = metric.__name__
            res = metric(sims, query_masks=meta["query_masks"])
            verbose(epoch=0, metrics=res, name=dataset, mode=metric_name)
            if trainer is not None:
                if not trainer.mini_train:
                    trainer.writer.set_step(step=0, mode="val")
                # avoid tensorboard folding by prefixing
                metric_name_ = f"test_{metric_name}"
                trainer.log_metrics(res, metric_name=metric_name_, mode="val")
            nested_metrics[metric_name] = res

    if data_loaders.num_test_captions == 1:
        visualizer.visualize_ranking(
            sims=sims,
            meta=meta,
            epoch=0,
            nested_metrics=nested_metrics,
        )
    log = {}
    for subkey, subval in nested_metrics.items():
        for subsubkey, subsubval in subval.items():
            log[f"test_{subkey}_{subsubkey}"] = subsubval
    for key, value in log.items():
        logger.info(" {:15s}: {}".format(str(key), value))

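# The metric functions above consume the raw `sims` matrix produced by the model.
# As a rough, hedged illustration of the kind of computation involved (this is not
# the implementation in `module_metric`, which additionally handles `query_masks`
# and multiple captions per video), recall@k for a square query-by-item similarity
# matrix whose correct matches lie on the diagonal can be computed as follows:
import numpy as np


def recall_at_k(sims: np.ndarray, k: int) -> float:
    """Fraction of queries whose matching item (diagonal entry) ranks in the top k."""
    # Item indices sorted by descending similarity for each query.
    order = np.argsort(-sims, axis=1)
    # Position of the diagonal (ground-truth) item within each sorted row.
    ranks = np.argmax(order == np.arange(sims.shape[0])[:, None], axis=1)
    return float(np.mean(ranks < k))


# Example: recall_at_k(sims, k=5) on the matrix computed in `evaluation` above,
# assuming one caption per video so that sims is square with diagonal matches.
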
def test(config):
    config.config['data_loader']['args']['mode'] = 'test'
    logger = config.get_logger('test')
    logger.info("Running test with configuration:")
    logger.info(config)

    expert_dims, raw_input_dims = compute_dims(config)

    if config['experts']['text_feat'] == 'learnable':
        # vocab
        vocab = Vocabulary()
        vocab.load('dataset/captions/dict.all_200k_gan.json')
        vocab_size = len(vocab)

        # word2vec
        if config['experts']['text_feat_init']:
            # word2vec: download the file and move it to the we_rootpath directory
            # https://www.kaggle.com/jacksoncrow/word2vec-flickr30k/version/1
            we_rootpath = '/home/yj/pretrained_model'
            w2v_data_path = os.path.join(we_rootpath, "word2vec/", 'flickr', 'vec500flickr30m')
            we_parameter = get_we_parameter(vocab, w2v_data_path)
        else:
            we_parameter = None
    else:
        vocab = None
        vocab_size = None
        we_parameter = None

    if "attr" in config['experts']['modalities']:
        attr_vocab = Vocabulary()
        attr_vocab.load('dataset/captions/dict.attr.json')
        attr_vocab_size = len(attr_vocab)
    else:
        attr_vocab = None
        attr_vocab_size = None

    data_loaders = config.init(name='data_loader',
                               module=module_data,
                               raw_input_dims=raw_input_dims,
                               text_feat=config['experts']['text_feat'],
                               text_dim=config['experts']['text_dim'],
                               vocab=vocab,
                               attr_vocab=attr_vocab,
                               pretrain=config['trainer']['pretrain'])

    model = config.init(name='arch',
                        module=module_arch,
                        expert_dims=expert_dims,
                        text_dim=config['experts']['text_dim'],
                        same_dim=config['experts']['ce_shared_dim'],
                        we_parameter=we_parameter,
                        vocab_size=vocab_size,
                        attr_vocab_size=attr_vocab_size,
                        text_feat=config['experts']['text_feat'])

    ckpt_path = Path(config._args.resume)
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Running test on {device}")

    model = model.to(device)
    model.eval()

    categories = ['dress', 'shirt', 'toptee']
    modalities = data_loaders[categories[0]].dataset.ordered_experts
    metric = {'score': dict()}

    for i, category in enumerate(categories):
        val_experts = {expert: list() for expert in modalities}
        target_ind = {expert: list() for expert in modalities}
        data_asin = []

        for batch in data_loaders[category + '_trg']:
            for key, val in batch['candidate_experts'].items():
                batch['candidate_experts'][key] = val.to(device)
            data_asin.extend([meta['candidate'] for meta in batch['meta_info']])
            for key, val in batch['candidate_ind'].items():
                target_ind[key].append(val)
            with torch.no_grad():
                experts, _, _ = model(batch['candidate_experts'],
                                      batch['candidate_ind'],
                                      target=True)
                for modality, val in experts.items():
                    val_experts[modality].append(val)

        for modality, val in val_experts.items():
            val_experts[modality] = torch.cat(val)
        for modality, val in target_ind.items():
            target_ind[modality] = torch.cat(val)

        scores = []
        meta_infos = []
        val_size = val_experts['resnet'].size(0)

        for batch in data_loaders[category]:
            for experts in ['candidate_experts']:
                for key, val in batch[experts].items():
                    batch[experts][key] = val.to(device)
            batch["text"] = batch["text"].to(device)
            batch_size = batch["text"].size(0)
            meta_infos.extend(list(batch['meta_info']))

            with torch.no_grad():
                # composition_feature, text, moe_weights = model(batch['candidate_experts'],
                #                                                batch['candidate_ind'],
                #                                                batch['text'],
                #                                                batch['text_bow'],
                #                                                batch['text_lengths'])
                # batch_target = dict()
                # for mod in modalities:
                #     tmp = []
                #     for k in range(batch_size):
                #         tmp.append(model.target_composition(val_experts[mod],
                #                                             text[mod][k].expand(val_size, -1)))
                #     batch_target[mod] = torch.stack(tmp)
                src_experts = model.image_encoder(batch['candidate_experts'],
                                                  batch['candidate_ind'])
                src_text, moe_weights = model.get_text_feature(batch['text'],
                                                               batch['candidate_ind'],
                                                               batch['text_bow'],
                                                               batch['text_lengths'])
                src_feature = model.get_combined_feature(src_experts, src_text)

                trg_text, _ = model.get_text_feature(batch['text'],
                                                     batch['target_ind'],
                                                     batch['text_bow'],
                                                     batch['text_lengths'],
                                                     target=True)
                # trg_text, _ = self.model.text_encoder['trg'](batch['text_mean'].unsqueeze(1), batch['target_ind'])

                batch_target = dict()
                for h, mod in enumerate(modalities):
                    tmp = []
                    for k in range(batch_size):
                        tmp.append(
                            model.trg_normalization_layer(
                                model.target_composition[h](
                                    val_experts[mod],
                                    trg_text[mod][k].expand(val_size, -1))))
                    batch_target[mod] = torch.stack(tmp)

                cross_view_conf_matrix = sharded_cross_view_inner_product(
                    vid_embds=batch_target,
                    text_embds=src_feature,
                    text_weights=moe_weights,
                    subspaces=model.image_encoder.modalities,
                    l2renorm=True,
                    dist=True,
                    val=True)
                scores.append(cross_view_conf_matrix)

        scores = torch.cat(scores)
        val_ids = data_loaders[category + '_trg'].dataset.data
        assert val_ids == data_asin
        metric['score'][category] = {
            'ids': val_ids,
            'matrix': scores,
            'meta_info': meta_infos,
        }

    save_fname = ckpt_path.parent / 'test_score.pt'
    tic = time.time()
    logger.info("Saving score matrix: {} ...".format(save_fname))
    torch.save(metric, save_fname)
    logger.info(f"Done in {time.time() - tic:.3f}s")

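# `sharded_cross_view_inner_product` combines per-modality similarities using the
# mixture-of-experts weights.  Below is a simplified, hedged sketch of the underlying
# idea only; it is not the repository's implementation, which additionally handles
# sharding, L2 renormalisation, masking and the 3-D per-query target embeddings
# constructed above.  The function name is illustrative.
from typing import Dict, List

import torch
import torch.nn.functional as F


def simple_cross_view_similarity(
        text_embds: Dict[str, torch.Tensor],   # {modality: (num_queries, dim)}
        vid_embds: Dict[str, torch.Tensor],    # {modality: (num_items, dim)}
        text_weights: torch.Tensor,            # (num_queries, num_modalities)
        modalities: List[str],
) -> torch.Tensor:
    """Weighted sum of per-modality cosine similarities between queries and items."""
    num_queries = text_weights.shape[0]
    num_items = next(iter(vid_embds.values())).shape[0]
    sims = text_weights.new_zeros(num_queries, num_items)
    for idx, mod in enumerate(modalities):
        txt = F.normalize(text_embds[mod], dim=-1)
        vid = F.normalize(vid_embds[mod], dim=-1)
        # Each modality contributes its similarity matrix, scaled per query.
        sims += text_weights[:, idx:idx + 1] * (txt @ vid.t())
    return sims
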
def train(config):
    """Cross-modal architecture training."""
    # Get the list of experts and their dimensions
    expert_dims = compute_dims(config)
    raw_input_dims = {}
    for expert, expert_dic in expert_dims.items():
        raw_input_dims[expert] = expert_dic["dim"]

    # Set the random initial seeds
    tic = time.time()
    seed = config["seed"]
    cross_seed = config.get("cross_seed", seed)
    logger.debug("Setting experiment random seed to %d", seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Tokenizer to parse sentences into tokens
    tokenizer = create_tokenizer(config["arch"]["args"]["txt_inp"])

    # Create the datasets
    logger.info("Preparing the dataloaders ...")
    dataset_types = ["train_sets", "continuous_eval_sets", "final_eval_sets"]
    data_loaders = {}
    loaded_data = {}
    for dataset_type in dataset_types:
        training = dataset_type == "train_sets"
        if not config.get(dataset_type, False):
            continue
        data_loaders[dataset_type] = []
        for _, data_loader in enumerate(config[dataset_type]):
            data_loaders[dataset_type].append(
                getattr(module_data, data_loader["type"])(
                    **data_loader["args"],
                    raw_input_dims=raw_input_dims,
                    training=training,
                    tokenizer=tokenizer,
                    loaded_data=loaded_data,
                    cross_seed=cross_seed,
                ))

    # Setup the cross-modal architecture
    model = config.init(
        name="arch",
        module=module_arch,
        expert_dims=expert_dims,
        tokenizer=tokenizer,
    )

    loss = config.init(name="loss", module=module_loss)
    metrics = [getattr(module_metric, met) for met in config["metrics"]]
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())

    if config["optimizer"]["type"] == "Ranger":
        optimizer = config.init("optimizer", ranger, trainable_params)
    else:
        optimizer = config.init("optimizer", torch.optim, trainable_params)

    lr_scheduler = config.init("lr_scheduler", torch.optim.lr_scheduler, optimizer)

    if "warmup_iterations" in config["optimizer"]:
        warmup_iterations = config["optimizer"]["warmup_iterations"]
    else:
        warmup_iterations = -1

    visualizer = config.init(
        name="visualizer",
        module=module_vis,
        exp_name=config.exper_name,
        web_dirs=config.web_dirs,
    )

    trainer = Trainer(
        model,
        loss,
        metrics,
        optimizer,
        config=config,
        data_loaders=data_loaders,
        lr_scheduler=lr_scheduler,
        visualizer=visualizer,
        skip_first_n_saves=config["trainer"].get("skip_first_n_saves", 0),
        include_optim_in_ckpts=config["trainer"].get("include_optim_in_ckpts", False),
        expert_dims=expert_dims,
        tokenizer=tokenizer,
        warmup_iterations=warmup_iterations)

    if not config.only_eval:
        logger.info("Training ...")
        trainer.train()
    logger.info("Final evaluation ...")
    trainer.evaluate()
    duration = time.strftime("%Hh%Mm%Ss", time.gmtime(time.time() - tic))
    logger.info("Script took %s", duration)

    # Report the location of the "best" checkpoint of the final seeded run (here
    # "best" corresponds to the model with the highest geometric mean over the
    # R@1, R@5 and R@10 metrics when a validation set is used, or simply the final
    # epoch of training for fixed-length schedules).
    best_ckpt_path = config.save_dir / "trained_model.pth"
    if os.path.exists(best_ckpt_path):
        logger.info("The best performing ckpt can be found at %s", str(best_ckpt_path))

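# `config.init(name=..., module=..., **kwargs)` follows the familiar pytorch-template
# pattern: look up config[name]["type"] in the given module and instantiate it with
# config[name]["args"] merged with the keyword overrides.  A hedged sketch of that
# dispatch follows; `init_from_config` is an illustrative name and the real
# `ConfigParser.init` in this codebase may differ in details (e.g. positional args).
def init_from_config(config: dict, name: str, module, **overrides):
    """Instantiate config[name]['type'] from `module` with config[name]['args']."""
    entry = config[name]
    kwargs = dict(entry.get("args", {}))
    kwargs.update(overrides)  # e.g. expert_dims, tokenizer, ...
    return getattr(module, entry["type"])(**kwargs)


# Hypothetical example, assuming config["optimizer"] = {"type": "Adam", "args": {"lr": 1e-4}}:
# optimizer = init_from_config(config, "optimizer", torch.optim,
#                              params=filter(lambda p: p.requires_grad, model.parameters()))
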
def evaluation(config, logger=None, trainer=None):
    if logger is None:
        logger = config.get_logger('test')

    if getattr(config._args, "eval_from_training_config", False):
        eval_conf = copy.deepcopy(config)
        merge(eval_conf._config, config["eval_settings"], strategy=Strategy.REPLACE)
        config = eval_conf

    logger.info("Running evaluation with configuration:")
    logger.info(config)
    expert_dims, raw_input_dims = compute_dims(config)
    trn_config = compute_trn_config(config)

    # Set the random initial seeds
    seed = config["seed"]
    logger.info(f"Setting experiment random seed to {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # We use cls defaults for backwards compatibility with the MMIT configs. In the
    # long run this should be handled by the json configs themselves
    cls_defaults = ["train", "val", "tiny", "challenge"]

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        text_feat=config["experts"]["text_feat"],
        text_dim=config["experts"]["text_dim"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing", False),
        task=config.get("task", "retrieval"),
        cls_partitions=config.get("cls_partitions", cls_defaults),
    )

    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=config["experts"]["text_dim"],
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    logger.info(model)

    metrics = [getattr(module_metric, met) for met in config['metrics']]
    visualizer = config.init(
        name='visualizer',
        module=module_vis,
        exp_name=config._exper_name,
        web_dir=config._web_log_dir,
    )

    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing. Note that some datasets fail to fit the retrieval
    # set on the GPU, so we run them on the CPU
    if torch.cuda.is_available() and not config.get("disable_gpu", True):
        device = "cuda"
    else:
        device = "cpu"
    logger.info(f"Running evaluation on {device}")

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        samples, meta = data_loaders["retrieval"]

        # To use the nan-checks safely, we need to make temporary copies of the data
        disable_nan_checks = config._config["disable_nan_checks"]
        with ctxt_mgr(samples, device, disable_nan_checks) as valid:
            output = model(**valid)

        sims = output["cross_view_conf_matrix"].data.cpu().float().numpy()
        dataset = data_loaders.dataset_name
        nested_metrics = {}
        for metric in metrics:
            metric_name = metric.__name__
            res = metric(sims, query_masks=meta["query_masks"])
            verbose(epoch=0, metrics=res, name=dataset, mode=metric_name)
            if trainer is not None:
                if not trainer.mini_train:
                    trainer.writer.set_step(step=0, mode="val")
                # avoid tensorboard folding by prefixing
                metric_name_ = f"test_{metric_name}"
                trainer.log_metrics(res, metric_name=metric_name_, mode="val")
            nested_metrics[metric_name] = res

    if data_loaders.num_test_captions == 1:
        visualizer.visualize_ranking(
            sims=sims,
            meta=meta,
            epoch=0,
            nested_metrics=nested_metrics,
        )
    log = {}
    for subkey, subval in nested_metrics.items():
        for subsubkey, subsubval in subval.items():
            log[f"test_{subkey}_{subsubkey}"] = subsubval
    for key, value in log.items():
        logger.info(" {:15s}: {}".format(str(key), value))

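# `ctxt_mgr` yields a temporary, device-placed copy of the retrieval samples so that
# the nan-checks can operate without mutating the cached dataloader output.  A hedged
# sketch of the idea follows; `device_copy` is an illustrative name, not the
# implementation used in this repository.
from contextlib import contextmanager

import torch


@contextmanager
def device_copy(samples: dict, device: str):
    """Yield a copy of `samples` with tensor values moved to `device`."""
    copied = {
        key: val.to(device) if isinstance(val, torch.Tensor) else val
        for key, val in samples.items()
    }
    yield copied


# Usage mirrors the call above:
# with device_copy(samples, device) as valid:
#     output = model(**valid)
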
def run_exp(config):
    warnings.filterwarnings('ignore')
    logger = config.get_logger('train')

    leaderboard_path = config._args.leaderboard
    Path(leaderboard_path).parent.mkdir(exist_ok=True, parents=True)
    with open(leaderboard_path, 'a') as f:
        txt_path = f"{config._log_dir}/preds.txt"
        print(txt_path, file=f, flush=True)

    expert_dims, raw_input_dims = compute_dims(config, logger)
    trn_config = compute_trn_config(config)

    if config._args.group_seed:
        seeds = [int(config._args.group_seed)]
    else:
        seeds = [int(x) for x in config._args.seeds.split(",")]

    # set up local filesystem on the cluster
    if socket.gethostname().endswith("cluster"):
        os.system(str(Path.home() / "configure_tmp_data.sh"))

    for ii, seed in enumerate(seeds):
        tic = time.time()
        logger.info(f"{ii + 1}/{len(seeds)} Setting experiment random seed to {seed}")
        set_seeds(seed)
        config["seed"] = seed

        # We use cls defaults for backwards compatibility with the MMIT configs. In the
        # long run this should be handled by the json configs themselves
        cls_defaults = ["train", "val", "tiny", "challenge"]

        model = config.init(
            name='arch',
            module=module_arch,
            expert_dims=expert_dims,
            text_dim=config["experts"]["text_dim"],
            disable_nan_checks=config["disable_nan_checks"],
            spatial_feats=config["data_loader"]["args"].get("spatial_feats", False),
            task=config.get("task", "retrieval"),
            ce_shared_dim=config["experts"].get("ce_shared_dim", None),
            feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
            trn_config=trn_config,
            trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
        )
        logger.info(model)

        data_loaders = config.init(
            name='data_loader',
            module=module_data,
            logger=logger,
            raw_input_dims=raw_input_dims,
            text_feat=config["experts"]["text_feat"],
            text_dim=config["experts"]["text_dim"],
            text_agg=config["experts"]["text_agg"],
            use_zeros_for_missing=config["experts"].get("use_zeros_for_missing", False),
            task=config.get("task", "retrieval"),
            cls_partitions=config.get("cls_partitions", cls_defaults),
        )

        if config.get("manual_linear_init", False):
            logger.info("manually setting init for linear layers")

            def init_weights(m):
                if isinstance(m, nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    m.bias.data.fill_(0.01)

            model.apply(init_weights)

        loss = config.init(name="loss", module=module_loss)
        metrics = [getattr(module_metric, met) for met in config['metrics']]
        trainable_params = filter(lambda p: p.requires_grad, model.parameters())

        if config["optimizer"]["type"] == "RAdam":
            optimizer = config.init('optimizer', radam, trainable_params)
        elif config["optimizer"]["type"] == "Ranger":
            optimizer = config.init('optimizer', ranger, trainable_params)
        elif config["optimizer"]["type"] == "SWATS":
            optimizer = config.init('optimizer', swats, trainable_params)
        else:
            optimizer = config.init('optimizer', torch.optim, trainable_params)

        if config["lr_scheduler"]["type"] == "StepLR":
            lr_scheduler = config.init('lr_scheduler', torch.optim.lr_scheduler, optimizer)
        else:
            lr_scheduler = config.init('lr_scheduler', cos_restart, optimizer)

        visualizer = config.init(
            name='visualizer',
            module=module_vis,
            exp_name=config._exper_name,
            web_dir=config._web_log_dir,
        )

        trainer = Trainer(
            model,
            loss,
            metrics,
            optimizer,
            config=config,
            data_loaders=data_loaders,
            lr_scheduler=lr_scheduler,
            mini_train=config._args.mini_train,
            disable_nan_checks=config["disable_nan_checks"],
            visualizer=visualizer,
            val_freq=config["trainer"].get("val_freq", 1),
            force_cpu_val=config.get("force_cpu_val", False),
            skip_first_n_saves=config["trainer"].get("skip_first_n_saves", 0),
            include_optim_in_ckpts=config["trainer"].get("include_optim_in_ckpts", 1),
            cache_targets=set(config.get("cache_targets", [])),
        )
        trainer.train()
        best_ckpt_path = config.save_dir / "trained_model.pth"
        duration = time.strftime('%Hh%Mm%Ss', time.gmtime(time.time() - tic))
        logger.info(f"Training took {duration}")

        if config._config.get("eval_settings", False):
            eval_config = copy.deepcopy(config)
            merge(eval_config._config, config["eval_settings"], strategy=Strategy.REPLACE)
            eval_config._args.resume = best_ckpt_path
            evaluation(eval_config, logger=logger, trainer=trainer)

    # If multiple runs were conducted, report relevant statistics
    if len(seeds) > 1:
        log_summary(
            logger=logger,
            log_path=config.log_path,
            eval_mode=config["eval_mode"],
            fixed_num_epochs=config["trainer"]["epochs"],
        )
        print(f"Log file stored at {config.log_path}")

    # Report the location of the "best" checkpoint of the final seeded run (here
    # "best" corresponds to the model with the highest geometric mean over the
    # R@1, R@5 and R@10 metrics when a validation set is used, or simply the final
    # epoch of training for fixed-length schedules).
    print(f"The best performing ckpt can be found at {str(best_ckpt_path)}")

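# `set_seeds` is expected to seed every source of randomness, consistent with the
# explicit seeding done in the other entry points above.  A hedged sketch follows;
# `set_all_seeds` is an illustrative name, and the repository's own helper may
# additionally configure cudnn determinism flags.
import random

import numpy as np
import torch


def set_all_seeds(seed: int) -> None:
    """Seed python, numpy and torch (CPU and, when available, CUDA) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
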
def main(config):
    logger = config.get_logger('train')
    expert_dims, raw_input_dims = compute_dims(config, logger)
    seeds = [int(x) for x in config._args.seeds.split(",")]

    for seed in seeds:
        # Set the random initial seeds
        tic = time.time()
        logger.info(f"Setting experiment random seed to {seed}")
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        data_loaders = config.init(
            name='data_loader',
            module=module_data,
            raw_input_dims=raw_input_dims,
            text_feat=config["experts"]["text_feat"],
            text_dim=config["experts"]["text_dim"],
        )

        model = config.init(
            name='arch',
            module=module_arch,
            expert_dims=expert_dims,
            text_dim=config["experts"]["text_dim"],
            disable_nan_checks=config["disable_nan_checks"],
        )
        logger.info(model)

        loss = config.init(name="loss", module=module_loss)
        metrics = [getattr(module_metric, met) for met in config['metrics']]
        trainable_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = config.init('optimizer', torch.optim, trainable_params)
        lr_scheduler = config.init('lr_scheduler', torch.optim.lr_scheduler, optimizer)
        visualizer = config.init(
            name='visualizer',
            module=module_vis,
            exp_name=config._exper_name,
            log_dir=config._web_log_dir,
        )

        trainer = Trainer(
            model,
            loss,
            metrics,
            optimizer,
            config=config,
            data_loaders=data_loaders,
            lr_scheduler=lr_scheduler,
            mini_train=config._args.mini_train,
            disable_nan_checks=config["disable_nan_checks"],
            visualizer=visualizer,
            skip_first_n_saves=config["trainer"].get("skip_first_n_saves", 0),
            include_optim_in_ckpts=config["trainer"].get("include_optim_in_ckpts", False),
        )
        trainer.train()
        best_ckpt_path = config.save_dir / "trained_model.pth"
        duration = time.strftime('%Hh%Mm%Ss', time.gmtime(time.time() - tic))
        logger.info(f"Training took {duration}")

        # If the dataset supports separate validation/test splits, the training config
        # json should specify an `eval_config` entry with the path to the test
        # configuration
        if config._config.get("eval_config", False):
            eval_args = argparse.ArgumentParser()
            eval_args.add_argument("--config", default=config["eval_config"])
            eval_args.add_argument("--device", default=config._args.device)
            eval_args.add_argument("--resume", default=best_ckpt_path)
            eval_config = ConfigParser(eval_args, slave_mode=True)
            evaluation(eval_config, logger=logger)

    # If multiple runs were conducted, report relevant statistics
    if len(seeds) > 1:
        log_summary(
            logger=logger,
            log_path=config.log_path,
            eval_mode=config["eval_mode"],
            fixed_num_epochs=config["trainer"]["epochs"],
        )
        print(f"Log file stored at {config.log_path}")

    # Report the location of the "best" checkpoint of the final seeded run (here
    # "best" corresponds to the model with the highest geometric mean over the
    # R@1, R@5 and R@10 metrics when a validation set is used, or simply the final
    # epoch of training for fixed-length schedules).
    print(f"The best performing ckpt can be found at {str(best_ckpt_path)}")

def test(config):
    logger = config.get_logger('test')
    logger.info("Running test with configuration:")
    logger.info(config)

    expert_dims = compute_dims(config)

    vocab = None
    vocab_size = None
    we_parameter = None

    if "attr" in config['experts']['modalities']:
        attr_vocab = Vocabulary()
        attr_vocab.load(
            os.path.join(config['data_loader']['args']['data_dir'],
                         'attributes/dict.attr.json'))
        attr_vocab_size = len(attr_vocab)
    else:
        attr_vocab = None
        attr_vocab_size = None

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        expert_dims=expert_dims,
        text_feat=config['experts']['text_feat'],
        text_dim=config['experts']['text_dim'],
    )

    model = config.init(name='arch',
                        module=module_arch,
                        expert_dims=expert_dims,
                        text_dim=config['experts']['text_dim'],
                        same_dim=config['experts']['ce_shared_dim'],
                        text_feat=config['experts']['text_feat'])

    trainer = TrainerJoint(
        model,
        loss=None,
        optimizer=None,
        config=config,
        data_loaders=data_loaders,
        lr_scheduler=None,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Running test on {device}")

    metric = trainer._valid_epoch(save_textatt=True)
    if config._args.mode == 'val':
        for key, value in metric.items():
            if key == 'recall_avg':
                logger.info(f'[Avg Recall] : {value}')
            elif key == 'recall_avg_corr':
                logger.info(f'[Avg Recall corr]: {value}')
            elif key == 'comb_avg':
                logger.info(f'[comb_avg] : {value}')
            elif key == 'recall':
                for i, category in zip(value, trainer.categories):
                    if len(i) == 2:
                        logger.info(f'[{category}] r@10, r@50: {i[0]}\t{i[1]}')
                    elif len(i) == 4:
                        logger.info(
                            f'[{category}] comp corr r@10, r@50: {i[0]}\t{i[1]}\t{i[2]}\t{i[3]}')
            elif key == 'comb':
                combstr = "comb:"
                for i, category in zip(value, trainer.categories):
                    combstr += f' {i[0]} {i[1]}'
                logger.info(combstr)
    else:
        save_fname = config.save_dir / 'test_score.pt'
        tic = time.time()
        logger.info("Saving score matrix: {} ...".format(save_fname))
        torch.save(metric, save_fname)
        logger.info(f"Done in {time.time() - tic:.3f}s")

def evaluation(config, logger=None):
    if logger is None:
        logger = config.get_logger('test')
    logger.info("Running evaluation with configuration:")
    logger.info(config)
    expert_dims, raw_input_dims = compute_dims(config)

    # Set the random initial seeds
    seed = config["seed"]
    logger.info(f"Setting experiment random seed to {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        raw_input_dims=raw_input_dims,
        text_feat=config["experts"]["text_feat"],
        text_dim=config["experts"]["text_dim"],
    )
    model = config.init(
        name='arch',
        module=module_arch,
        expert_dims=expert_dims,
        text_dim=config["experts"]["text_dim"],
        disable_nan_checks=config["disable_nan_checks"],
    )
    logger.info(model)

    metrics = [getattr(module_metric, met) for met in config['metrics']]
    visualizer = config.init(
        name='visualizer',
        module=module_vis,
        exp_name=config._exper_name,
        log_dir=config._web_log_dir,
    )

    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing. Note that some datasets fail to fit the retrieval
    # set on the GPU, so we run them on the CPU
    if torch.cuda.is_available() and not config.get("disable_gpu", False):
        device = "cuda"
    else:
        device = "cpu"
    logger.info(f"Running evaluation on {device}")

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        samples, meta = data_loaders["retrieval"]

        # To use the nan-checks safely, we need to make temporary copies of the data
        disable_nan_checks = config._config["disable_nan_checks"]
        with valid_samples(samples, device, disable_nan_checks) as valid:
            output = model(**valid)

        sims = output["cross_view_conf_matrix"].data.cpu().float().numpy()
        dataset = data_loaders.dataset_name
        nested_metrics = {}
        for metric in metrics:
            metric_name = metric.__name__
            res = metric(sims, query_masks=meta["query_masks"])
            verbose(epoch=0, metrics=res, name=dataset, mode=metric_name)
            nested_metrics[metric_name] = res

    if data_loaders.num_test_captions == 1:
        visualizer.visualize_ranking(
            sims=sims,
            meta=meta,
            epoch=0,
            nested_metrics=nested_metrics,
        )
    log = {}
    for subkey, subval in nested_metrics.items():
        for subsubkey, subsubval in subval.items():
            log[f"test_{subkey}_{subsubkey}"] = subsubval
    for key, value in log.items():
        logger.info(' {:15s}: {}'.format(str(key), value))