Example #1
def get_model_and_data_loaders(
    config: ConfigParser,
    logger: logging.Logger,
    ckpt_path: Path,
) -> Tuple[torch.nn.Module, module_data.ExpertDataLoader]:
    expert_dims, raw_input_dims, text_dim = compute_dims(config)

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        challenge_mode=config.get("challenge_mode", False),
        text_dim=text_dim,
        text_feat=config["experts"]["text_feat"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing",
                                                    False),
        task=config.get("task", "retrieval"),
        eval_only=True,
        distil_params=config.get("distil_params", None),
        training_file=config.get("training_file", None),
        caption_masks=config.get("caption_masks", None),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
    )

    trn_config = compute_trn_config(config)
    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=text_dim,
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    # support backwards compatibility
    deprecated = ["ce.moe_fc_bottleneck1", "ce.moe_cg", "ce.moe_fc_proj"]
    for mod in deprecated:
        for suffix in ("weight", "bias"):
            key = f"{mod}.{suffix}"
            if key in state_dict:
                print(f"WARNING: Removing deprecated key {key} from model")
                state_dict.pop(key)
    model.load_state_dict(state_dict)

    return model, data_loaders
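
A minimal usage sketch for the helper above, assuming a ConfigParser instance and logger have already been created by the surrounding entry point (as in the evaluation() examples further below); load_for_eval itself is hypothetical:

from pathlib import Path

import torch


def load_for_eval(config, logger):
    # Build the model/data loaders and move the model to the evaluation device.
    model, data_loaders = get_model_and_data_loaders(
        config=config,
        logger=logger,
        ckpt_path=Path(config._args.resume),
    )
    # Device selection mirrors the pattern used by evaluation() below.
    device = "cuda" if torch.cuda.is_available() and not config.get("disable_gpu", True) else "cpu"
    model = model.to(device)
    model.eval()
    return model, data_loaders, device
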
Example #2
def get_model_and_data_loaders(
        config: ConfigParser,
        logger: logging.Logger,
        ckpt_path: Path,
) -> Tuple[torch.nn.Module, module_data.ExpertDataLoader]:
    expert_dims, raw_input_dims = compute_dims(config)
    trn_config = compute_trn_config(config)

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        challenge_mode=config.get("challenge_mode", False),
        text_feat=config["experts"]["text_feat"],
        text_dim=config["experts"]["text_dim"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing", False),
        task=config.get("task", "retrieval"),
        eval_only=True,
    )

    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=config["experts"]["text_dim"],
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    return model, data_loaders
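
Unlike Example #1, this variant loads the checkpoint without stripping the removed MoE layers. If older checkpoints must still be supported, the backwards-compatibility loop from Example #1 could be factored into a small helper and called just before model.load_state_dict(state_dict); strip_deprecated_keys below is such a sketch, not part of the original code:

def strip_deprecated_keys(state_dict,
                          deprecated=("ce.moe_fc_bottleneck1", "ce.moe_cg", "ce.moe_fc_proj")):
    # Drop weight/bias entries of deprecated modules from a checkpoint state dict.
    for mod in deprecated:
        for suffix in ("weight", "bias"):
            key = f"{mod}.{suffix}"
            if key in state_dict:
                print(f"WARNING: Removing deprecated key {key} from model")
                state_dict.pop(key)
    return state_dict
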
Example #3
def evaluation(config, logger=None, trainer=None):

    if logger is None:
        logger = config.get_logger('test')

    if getattr(config._args, "eval_from_training_config", False):
        eval_conf = copy.deepcopy(config)
        merge(eval_conf._config,
              config["eval_settings"],
              strategy=Strategy.REPLACE)
        config = eval_conf

    logger.info("Running evaluation with configuration:")
    logger.info(config)

    expert_dims, raw_input_dims = compute_dims(config)
    trn_config = compute_trn_config(config)

    # Set the random initial seeds
    seed = config["seed"]
    logger.info(f"Setting experiment random seed to {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    update_src_web_video_dir(config)
    visualizer = config.init(
        name='visualizer',
        module=module_vis,
        exp_name=config._exper_name,
        web_dir=config._web_log_dir,
    )

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        challenge_mode=config.get("challenge_mode", False),
        text_feat=config["experts"]["text_feat"],
        text_dim=config["experts"]["text_dim"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing",
                                                    False),
        task=config.get("task", "retrieval"),
        eval_only=True,
    )

    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=config["experts"]["text_dim"],
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    logger.info(model)

    metrics = [getattr(module_metric, met) for met in config['metrics']]
    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)
    challenge_mode = config.get("challenge_mode", False)
    challenge_msg = (
        "\n"
        "Evaluation ran on challenge features. To obtain a score, upload the similarity"
        "matrix for each dataset to the test server after running the "
        "`misc/cvpr2020-challenge/prepare_submission.py` script and following the "
        "instructions at: "
        "https://www.robots.ox.ac.uk/~vgg/challenges/video-pentathlon/"
        "\n")

    # prepare model for testing.  Note that some datasets fail to fit the retrieval
    # set on the GPU, so we run them on the CPU
    if torch.cuda.is_available() and not config.get("disable_gpu", True):
        device = "cuda"
    else:
        device = "cpu"
    logger.info(f"Running evaluation on {device}")

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        samples, meta = data_loaders["retrieval"]

        # To use the nan-checks safely, we need to make temporary copies of the data
        disable_nan_checks = config._config["disable_nan_checks"]
        with ctxt_mgr(samples, device, disable_nan_checks) as valid:
            output = model(**valid)

        sims = output["cross_view_conf_matrix"].data.cpu().float().numpy()
        dataset = data_loaders.dataset_name
        if challenge_mode:
            split = data_loaders.dataloaders["dataset"].split_name
            prediction_path = config._log_dir / f"{dataset}-{split}-predictions.csv"
            compressed_preds = compress_predictions(
                query_masks=meta["query_masks"],
                sims=sims,
            )
            np.savetxt(prediction_path,
                       compressed_preds,
                       delimiter=',',
                       fmt="%d")
            print(f"Saved similarity matrix predictions to {prediction_path}")
            print(challenge_msg)
            return

        nested_metrics = {}
        for metric in metrics:
            metric_name = metric.__name__
            res = metric(sims, query_masks=meta["query_masks"])
            verbose(epoch=0, metrics=res, name=dataset, mode=metric_name)
            if trainer is not None:
                if not trainer.mini_train:
                    trainer.writer.set_step(step=0, mode="val")
                # avoid tensorboard folding by prefixing
                metric_name_ = f"test_{metric_name}"
                trainer.log_metrics(res, metric_name=metric_name_, mode="val")
            nested_metrics[metric_name] = res

    if data_loaders.num_test_captions == 1:
        visualizer.visualize_ranking(
            sims=sims,
            meta=meta,
            epoch=0,
            nested_metrics=nested_metrics,
        )
    log = {}
    for subkey, subval in nested_metrics.items():
        for subsubkey, subsubval in subval.items():
            log[f"test_{subkey}_{subsubkey}"] = subsubval
    for key, value in log.items():
        logger.info(" {:15s}: {}".format(str(key), value))
Example #4
def test(config):
    config.config['data_loader']['args']['mode'] = 'test'
    logger = config.get_logger('test')
    logger.info("Running test with configuration:")
    logger.info(config)

    expert_dims, raw_input_dims = compute_dims(config)

    if config['experts']['text_feat'] == 'learnable':
        # vocab
        vocab = Vocabulary()
        vocab.load('dataset/captions/dict.all_200k_gan.json')
        vocab_size = len(vocab)

        # word2vec
        if config['experts']['text_feat_init']:
            # word2vec, download file and move to we_root-path directory
            # https://www.kaggle.com/jacksoncrow/word2vec-flickr30k/version/1
            we_rootpath = '/home/yj/pretrained_model'
            w2v_data_path = os.path.join(we_rootpath, "word2vec/", 'flickr',
                                         'vec500flickr30m')
            we_parameter = get_we_parameter(vocab, w2v_data_path)
        else:
            we_parameter = None
    else:
        vocab = None
        vocab_size = None
        we_parameter = None

    if "attr" in config['experts']['modalities']:
        attr_vocab = Vocabulary()
        attr_vocab.load('dataset/captions/dict.attr.json')
        attr_vocab_size = len(attr_vocab)
    else:
        attr_vocab = None
        attr_vocab_size = None

    data_loaders = config.init(name='data_loader',
                               module=module_data,
                               raw_input_dims=raw_input_dims,
                               text_feat=config['experts']['text_feat'],
                               text_dim=config['experts']['text_dim'],
                               vocab=vocab,
                               attr_vocab=attr_vocab,
                               pretrain=config['trainer']['pretrain'])

    model = config.init(name='arch',
                        module=module_arch,
                        expert_dims=expert_dims,
                        text_dim=config['experts']['text_dim'],
                        same_dim=config['experts']['ce_shared_dim'],
                        we_parameter=we_parameter,
                        vocab_size=vocab_size,
                        attr_vocab_size=attr_vocab_size,
                        text_feat=config['experts']['text_feat'])

    ckpt_path = Path(config._args.resume)
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Running test on {device}")

    model = model.to(device)
    model.eval()

    categories = ['dress', 'shirt', 'toptee']
    modalities = data_loaders[categories[0]].dataset.ordered_experts
    metric = {'score': dict()}

    for i, category in enumerate(categories):
        val_experts = {expert: list() for expert in modalities}
        target_ind = {expert: list() for expert in modalities}
        data_asin = []

        for batch in data_loaders[category + '_trg']:
            for key, val in batch['candidate_experts'].items():
                batch['candidate_experts'][key] = val.to(device)

            data_asin.extend(
                [meta['candidate'] for meta in batch['meta_info']])

            for key, val in batch['candidate_ind'].items():
                target_ind[key].append(val)

            with torch.no_grad():
                experts, _, _ = model(batch['candidate_experts'],
                                      batch['candidate_ind'],
                                      target=True)
                for modality, val in experts.items():
                    val_experts[modality].append(val)

        for modality, val in val_experts.items():
            val_experts[modality] = torch.cat(val)

        for modality, val in target_ind.items():
            target_ind[modality] = torch.cat(val)

        scores = []
        meta_infos = []
        val_size = val_experts['resnet'].size(0)

        for batch in data_loaders[category]:
            for experts in ['candidate_experts']:
                for key, val in batch[experts].items():
                    batch[experts][key] = val.to(device)
            batch["text"] = batch["text"].to(device)
            batch_size = batch["text"].size(0)

            meta_infos.extend(list(batch['meta_info']))

            with torch.no_grad():
                # composition_feature, text, moe_weights = model(batch['candidate_experts'],
                #                                                batch['candidate_ind'],
                #                                                batch['text'],
                #                                                batch['text_bow'],
                #                                                batch['text_lengths'])

                # batch_target = dict()
                # for mod in modalities:
                #     tmp = []
                #     for k in range(batch_size):
                #         tmp.append(model.target_composition(val_experts[mod], text[mod][k].expand(val_size, -1)))
                #     batch_target[mod] = torch.stack(tmp)

                src_experts = model.image_encoder(batch['candidate_experts'],
                                                  batch['candidate_ind'])
                src_text, moe_weights = model.get_text_feature(
                    batch['text'], batch['candidate_ind'], batch['text_bow'],
                    batch['text_lengths'])
                src_feature = model.get_combined_feature(src_experts, src_text)

                trg_text, _ = model.get_text_feature(batch['text'],
                                                     batch['target_ind'],
                                                     batch['text_bow'],
                                                     batch['text_lengths'],
                                                     target=True)
                # trg_text, _ = self.model.text_encoder['trg'](batch['text_mean'].unsqueeze(1), batch['target_ind'])

                batch_target = dict()
                for h, mod in enumerate(modalities):
                    tmp = []
                    for k in range(batch_size):
                        tmp.append(
                            model.trg_normalization_layer(
                                model.target_composition[h](
                                    val_experts[mod],
                                    trg_text[mod][k].expand(val_size, -1))))
                    batch_target[mod] = torch.stack(tmp)

                cross_view_conf_matrix = sharded_cross_view_inner_product(
                    vid_embds=batch_target,
                    text_embds=src_feature,
                    text_weights=moe_weights,
                    subspaces=model.image_encoder.modalities,
                    l2renorm=True,
                    dist=True,
                    val=True)

                scores.append(cross_view_conf_matrix)
        scores = torch.cat(scores)
        val_ids = data_loaders[category + '_trg'].dataset.data
        assert val_ids == data_asin
        metric['score'][category] = {
            'ids': val_ids,
            'matrix': scores,
            'meta_info': meta_infos
        }

    save_fname = ckpt_path.parent / 'test_score.pt'
    tic = time.time()
    logger.info("Saving score matrix: {} ...".format(save_fname))
    torch.save(metric, save_fname)
    logger.info(f"Done in {time.time() - tic:.3f}s")
Example #5
def train(config):
    """Cross-modal architecture training."""

    # Get the list of experts and their dimensions
    expert_dims = compute_dims(config)
    raw_input_dims = {}
    for expert, expert_dic in expert_dims.items():
        raw_input_dims[expert] = expert_dic["dim"]

    # Set the random initial seeds
    tic = time.time()
    seed = config["seed"]
    cross_seed = config.get("cross_seed", seed)
    logger.debug("Setting experiment random seed to %d", seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Tokenizer to parse sentences into tokens
    tokenizer = create_tokenizer(config["arch"]["args"]["txt_inp"])

    # Create the datasets
    logger.info("Preparing the dataloaders ...")
    dataset_types = ["train_sets", "continuous_eval_sets", "final_eval_sets"]
    data_loaders = {}
    loaded_data = {}
    for dataset_type in dataset_types:
        training = dataset_type == "train_sets"
        if not config.get(dataset_type, False):
            continue
        data_loaders[dataset_type] = []
        for _, data_loader in enumerate(config[dataset_type]):
            data_loaders[dataset_type].append(
                getattr(module_data, data_loader["type"])(
                    **data_loader["args"],
                    raw_input_dims=raw_input_dims,
                    training=training,
                    tokenizer=tokenizer,
                    loaded_data=loaded_data,
                    cross_seed=cross_seed,
                ))

    # Setup the cross-modal architecture
    model = config.init(
        name="arch",
        module=module_arch,
        expert_dims=expert_dims,
        tokenizer=tokenizer,
    )

    loss = config.init(name="loss", module=module_loss)
    metrics = [getattr(module_metric, met) for met in config["metrics"]]
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())

    if config["optimizer"]["type"] == "Ranger":
        optimizer = config.init("optimizer", ranger, trainable_params)
    else:
        optimizer = config.init("optimizer", torch.optim, trainable_params)

    lr_scheduler = config.init("lr_scheduler", torch.optim.lr_scheduler,
                               optimizer)

    if "warmup_iterations" in config["optimizer"]:
        warmup_iterations = config["optimizer"]["warmup_iterations"]
    else:
        warmup_iterations = -1

    visualizer = config.init(
        name="visualizer",
        module=module_vis,
        exp_name=config.exper_name,
        web_dirs=config.web_dirs,
    )

    trainer = Trainer(
        model,
        loss,
        metrics,
        optimizer,
        config=config,
        data_loaders=data_loaders,
        lr_scheduler=lr_scheduler,
        visualizer=visualizer,
        skip_first_n_saves=config["trainer"].get("skip_first_n_saves", 0),
        include_optim_in_ckpts=config["trainer"].get("include_optim_in_ckpts",
                                                     False),
        expert_dims=expert_dims,
        tokenizer=tokenizer,
        warmup_iterations=warmup_iterations)

    if not config.only_eval:
        logger.info("Training ...")
        trainer.train()
    logger.info("Final evaluation ...")
    trainer.evaluate()
    duration = time.strftime("%Hh%Mm%Ss", time.gmtime(time.time() - tic))
    logger.info("Script took %s", duration)

    # Report the location of the "best" checkpoint of the final seeded run (here
    # "best" corresponds to the model with the highest geometric mean over the
    # R@1, R@5 and R@10 metrics when a validation set is used, or simply the final
    # epoch of training for fixed-length schedules).
    best_ckpt_path = config.save_dir / "trained_model.pth"
    if os.path.exists(best_ckpt_path):
        logger.info("The best performing ckpt can be found at %s",
                    str(best_ckpt_path))
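
All of these examples lean on the project-specific ConfigParser.init(name, module, **extra) helper. Its exact implementation is not shown here, but the manual data-loader construction above (getattr(module_data, data_loader["type"])(**data_loader["args"], ...)) suggests the convention it follows: look up the type and args of the named config section, merge in the extra keyword arguments, and instantiate the class from the given module. A rough, assumed equivalent:

def init_from_config(config, name, module, **extra_args):
    # Illustrative stand-in for ConfigParser.init: build config[name]["type"] from `module`.
    entry = config[name]
    kwargs = dict(entry.get("args", {}))
    kwargs.update(extra_args)  # runtime arguments extend/override the json-config args
    return getattr(module, entry["type"])(**kwargs)
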
Example #6
def evaluation(config, logger=None, trainer=None):

    if logger is None:
        logger = config.get_logger('test')

    if getattr(config._args, "eval_from_training_config", False):
        eval_conf = copy.deepcopy(config)
        merge(eval_conf._config,
              config["eval_settings"],
              strategy=Strategy.REPLACE)
        config = eval_conf

    logger.info("Running evaluation with configuration:")
    logger.info(config)

    expert_dims, raw_input_dims = compute_dims(config)
    trn_config = compute_trn_config(config)

    # Set the random initial seeds
    seed = config["seed"]
    logger.info(f"Setting experiment random seed to {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # We use cls defaults for backwards compatibility with the MMIT configs.  In the
    # long run this should be handled by the json configs themselves
    cls_defaults = ["train", "val", "tiny", "challenge"]

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        text_feat=config["experts"]["text_feat"],
        text_dim=config["experts"]["text_dim"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing",
                                                    False),
        task=config.get("task", "retrieval"),
        cls_partitions=config.get("cls_partitions", cls_defaults),
    )

    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=config["experts"]["text_dim"],
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    logger.info(model)

    metrics = [getattr(module_metric, met) for met in config['metrics']]
    visualizer = config.init(
        name='visualizer',
        module=module_vis,
        exp_name=config._exper_name,
        web_dir=config._web_log_dir,
    )
    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing.  Note that some datasets fail to fit the retrieval
    # set on the GPU, so we run them on the CPU
    if torch.cuda.is_available() and not config.get("disable_gpu", True):
        device = "cuda"
    else:
        device = "cpu"
    logger.info(f"Running evaluation on {device}")

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        samples, meta = data_loaders["retrieval"]

        # To use the nan-checks safely, we need to make temporary copies of the data
        disable_nan_checks = config._config["disable_nan_checks"]
        with ctxt_mgr(samples, device, disable_nan_checks) as valid:
            output = model(**valid)

        sims = output["cross_view_conf_matrix"].data.cpu().float().numpy()
        dataset = data_loaders.dataset_name
        nested_metrics = {}
        for metric in metrics:
            metric_name = metric.__name__
            res = metric(sims, query_masks=meta["query_masks"])
            verbose(epoch=0, metrics=res, name=dataset, mode=metric_name)
            if trainer is not None:
                if not trainer.mini_train:
                    trainer.writer.set_step(step=0, mode="val")
                # avoid tensorboard folding by prefixing
                metric_name_ = f"test_{metric_name}"
                trainer.log_metrics(res, metric_name=metric_name_, mode="val")
            nested_metrics[metric_name] = res

    if data_loaders.num_test_captions == 1:
        visualizer.visualize_ranking(
            sims=sims,
            meta=meta,
            epoch=0,
            nested_metrics=nested_metrics,
        )
    log = {}
    for subkey, subval in nested_metrics.items():
        for subsubkey, subsubval in subval.items():
            log[f"test_{subkey}_{subsubkey}"] = subsubval
    for key, value in log.items():
        logger.info(" {:15s}: {}".format(str(key), value))
Example #7
def run_exp(config):
    
    warnings.filterwarnings('ignore')
    logger = config.get_logger('train')

    leaderboard_path = config._args.leaderboard
    Path(leaderboard_path).parent.mkdir(exist_ok=True, parents=True)
    with open(leaderboard_path, 'a') as f:
        txt_path = f"{config._log_dir}/preds.txt"
        print(txt_path, file=f, flush=True)

    expert_dims, raw_input_dims = compute_dims(config, logger)
    trn_config = compute_trn_config(config)

    if config._args.group_seed:
        seeds = [int(config._args.group_seed)]
    else:
        seeds = [int(x) for x in config._args.seeds.split(",")]

    # set up local filesystem on the cluster
    if socket.gethostname().endswith("cluster"):
        os.system(str(Path.home() / "configure_tmp_data.sh"))

    for ii, seed in enumerate(seeds):
        tic = time.time()
        logger.info(f"{ii + 1}/{len(seeds)} Setting experiment random seed to {seed}")
        set_seeds(seed)
        config["seed"] = seed

        # We use cls defaults for backwards compatibility with the MMIT configs.  In the
        # long run this should be handled by the json configs themselves
        cls_defaults = ["train", "val", "tiny", "challenge"]

        model = config.init(
            name='arch',
            module=module_arch,
            expert_dims=expert_dims,
            text_dim=config["experts"]["text_dim"],
            disable_nan_checks=config["disable_nan_checks"],
            spatial_feats=config["data_loader"]["args"].get("spatial_feats", False),
            task=config.get("task", "retrieval"),
            ce_shared_dim=config["experts"].get("ce_shared_dim", None),
            feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
            trn_config=trn_config,
            trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
        )
        logger.info(model)

        data_loaders = config.init(
            name='data_loader',
            module=module_data,
            logger=logger,
            raw_input_dims=raw_input_dims,
            text_feat=config["experts"]["text_feat"],
            text_dim=config["experts"]["text_dim"],
            text_agg=config["experts"]["text_agg"],
            use_zeros_for_missing=config["experts"].get("use_zeros_for_missing", False),
            task=config.get("task", "retrieval"),
            cls_partitions=config.get("cls_partitions", cls_defaults)
        )

        if config.get("manual_linear_init", False):
            logger.info("manually setting init for linear layers")
            def init_weights(m):
                if isinstance(m, nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    m.bias.data.fill_(0.01)
            model.apply(init_weights)

        loss = config.init(name="loss", module=module_loss)
        metrics = [getattr(module_metric, met) for met in config['metrics']]
        trainable_params = filter(lambda p: p.requires_grad, model.parameters())

        if config["optimizer"]["type"] == "RAdam":
            optimizer = config.init('optimizer', radam, trainable_params)
        elif config["optimizer"]["type"] == "Ranger":
            optimizer = config.init('optimizer', ranger, trainable_params)
        elif config["optimizer"]["type"] == "SWATS":
            optimizer = config.init('optimizer', swats, trainable_params)
        else:
            optimizer = config.init('optimizer', torch.optim, trainable_params)

        if config["lr_scheduler"]["type"] == "StepLR":
            lr_scheduler = config.init('lr_scheduler', torch.optim.lr_scheduler,
                                       optimizer)
        else:
            lr_scheduler = config.init('lr_scheduler', cos_restart, optimizer)

        visualizer = config.init(
            name='visualizer',
            module=module_vis,
            exp_name=config._exper_name,
            web_dir=config._web_log_dir,
        )

        trainer = Trainer(
            model,
            loss,
            metrics,
            optimizer,
            config=config,
            data_loaders=data_loaders,
            lr_scheduler=lr_scheduler,
            mini_train=config._args.mini_train,
            disable_nan_checks=config["disable_nan_checks"],
            visualizer=visualizer,
            val_freq=config["trainer"].get("val_freq", 1),
            force_cpu_val=config.get("force_cpu_val", False),
            skip_first_n_saves=config["trainer"].get("skip_first_n_saves", 0),
            include_optim_in_ckpts=config["trainer"].get("include_optim_in_ckpts", 1),
            cache_targets=set(config.get("cache_targets", [])),
        )
        trainer.train()
        best_ckpt_path = config.save_dir / "trained_model.pth"
        duration = time.strftime('%Hh%Mm%Ss', time.gmtime(time.time() - tic))
        logger.info(f"Training took {duration}")

        if config._config.get("eval_settings", False):
            eval_config = copy.deepcopy(config)
            merge(eval_config._config, config["eval_settings"], strategy=Strategy.REPLACE)
            eval_config._args.resume = best_ckpt_path
            evaluation(eval_config, logger=logger, trainer=trainer)

    # If multiple runs were conducted, report relevant statistics
    if len(seeds) > 1:
        log_summary(
            logger=logger,
            log_path=config.log_path,
            eval_mode=config["eval_mode"],
            fixed_num_epochs=config["trainer"]["epochs"],
        )
    print(f"Log file stored at {config.log_path}")

    # Report the location of the "best" checkpoint of the final seeded run (here
    # "best" corresponds to the model with the highest geometric mean over the
    # R@1, R@5 and R@10 metrics when a validation set is used, or simply the final
    # epoch of training for fixed-length schedules).
    print(f"The best performing ckpt can be found at {str(best_ckpt_path)}")
Example #8
def main(config):
    logger = config.get_logger('train')
    expert_dims, raw_input_dims = compute_dims(config, logger)
    seeds = [int(x) for x in config._args.seeds.split(",")]

    for seed in seeds:
        # Set the random initial seeds
        tic = time.time()
        logger.info(f"Setting experiment random seed to {seed}")
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        data_loaders = config.init(
            name='data_loader',
            module=module_data,
            raw_input_dims=raw_input_dims,
            text_feat=config["experts"]["text_feat"],
            text_dim=config["experts"]["text_dim"],
        )

        model = config.init(
            name='arch',
            module=module_arch,
            expert_dims=expert_dims,
            text_dim=config["experts"]["text_dim"],
            disable_nan_checks=config["disable_nan_checks"],
        )
        logger.info(model)

        loss = config.init(name="loss", module=module_loss)
        metrics = [getattr(module_metric, met) for met in config['metrics']]
        trainable_params = filter(lambda p: p.requires_grad,
                                  model.parameters())

        optimizer = config.init('optimizer', torch.optim, trainable_params)
        lr_scheduler = config.init('lr_scheduler', torch.optim.lr_scheduler,
                                   optimizer)
        visualizer = config.init(
            name='visualizer',
            module=module_vis,
            exp_name=config._exper_name,
            log_dir=config._web_log_dir,
        )

        trainer = Trainer(
            model,
            loss,
            metrics,
            optimizer,
            config=config,
            data_loaders=data_loaders,
            lr_scheduler=lr_scheduler,
            mini_train=config._args.mini_train,
            disable_nan_checks=config["disable_nan_checks"],
            visualizer=visualizer,
            skip_first_n_saves=config["trainer"].get("skip_first_n_saves", 0),
            include_optim_in_ckpts=config["trainer"].get(
                "include_optim_in_ckpts", False),
        )
        trainer.train()
        best_ckpt_path = config.save_dir / "trained_model.pth"
        duration = time.strftime('%Hh%Mm%Ss', time.gmtime(time.time() - tic))
        logger.info(f"Training took {duration}")

        # If the dataset supports separate validation/test splits, the training config
        # json should specify an `eval_config` entry with the path to the test
        # configuration
        if config._config.get("eval_config", False):
            eval_args = argparse.ArgumentParser()
            eval_args.add_argument("--config", default=config["eval_config"])
            eval_args.add_argument("--device", default=config._args.device)
            eval_args.add_argument("--resume", default=best_ckpt_path)
            eval_config = ConfigParser(eval_args, slave_mode=True)
            evaluation(eval_config, logger=logger)

    # If multiple runs were conducted, report relevant statistics
    if len(seeds) > 1:
        log_summary(
            logger=logger,
            log_path=config.log_path,
            eval_mode=config["eval_mode"],
            fixed_num_epochs=config["trainer"]["epochs"],
        )
    print(f"Log file stored at {config.log_path}")

    # Report the location of the "best" checkpoint of the final seeded run (here
    # "best" corresponds to the model with the highest geometric mean over the
    # R@1, R@5 and R@10 metrics when a validation set is used, or simply the final
    # epoch of training for fixed-length schedules).
    print(f"The best performing ckpt can be found at {str(best_ckpt_path)}")
Example #9
def test(config):
    logger = config.get_logger('test')
    logger.info("Running test with configuration:")
    logger.info(config)

    expert_dims = compute_dims(config)

    vocab = None
    vocab_size = None
    we_parameter = None

    if "attr" in config['experts']['modalities']:
        attr_vocab = Vocabulary()
        attr_vocab.load(
            os.path.join(config['data_loader']['args']['data_dir'],
                         'attributes/dict.attr.json'))
        attr_vocab_size = len(attr_vocab)
    else:
        attr_vocab = None
        attr_vocab_size = None

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        expert_dims=expert_dims,
        text_feat=config['experts']['text_feat'],
        text_dim=config['experts']['text_dim'],
    )

    model = config.init(name='arch',
                        module=module_arch,
                        expert_dims=expert_dims,
                        text_dim=config['experts']['text_dim'],
                        same_dim=config['experts']['ce_shared_dim'],
                        text_feat=config['experts']['text_feat'])
    trainer = TrainerJoint(
        model,
        loss=None,
        optimizer=None,
        config=config,
        data_loaders=data_loaders,
        lr_scheduler=None,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Running test on {device}")

    metric = trainer._valid_epoch(save_textatt=True)

    if config._args.mode == 'val':
        for key, value in metric.items():
            if key == 'recall_avg':
                logger.info(f'[Avg Recall]     : {value}')
            elif key == 'recall_avg_corr':
                logger.info(f'[Avg Recall corr]: {value}')
            elif key == 'comb_avg':
                logger.info(f'[comb_avg]       : {value}')
            elif key == 'recall':
                for i, category in zip(value, trainer.categories):
                    if len(i) == 2:
                        logger.info(f'[{category}] r@10, r@50: {i[0]}\t{i[1]}')
                    elif len(i) == 4:
                        logger.info(
                            f'[{category}] comp corr r@10, r@50: {i[0]}\t{i[1]}\t{i[2]}\t{i[3]}'
                        )
            elif key == 'comb':
                combstr = "comb:"
                for i, category in zip(value, trainer.categories):
                    combstr += f' {i[0]} {i[1]}'
                logger.info(combstr)
    else:
        save_fname = config.save_dir / 'test_score.pt'
        tic = time.time()
        logger.info("Saving score matrix: {} ...".format(save_fname))
        torch.save(metric, save_fname)
        logger.info(f"Done in {time.time() - tic:.3f}s")
Example #10
def evaluation(config, logger=None):

    if logger is None:
        logger = config.get_logger('test')

    logger.info("Running evaluation with configuration:")
    logger.info(config)

    expert_dims, raw_input_dims = compute_dims(config)

    # Set the random initial seeds
    seed = config["seed"]
    logger.info(f"Setting experiment random seed to {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        raw_input_dims=raw_input_dims,
        text_feat=config["experts"]["text_feat"],
        text_dim=config["experts"]["text_dim"],
    )
    model = config.init(
        name='arch',
        module=module_arch,
        expert_dims=expert_dims,
        text_dim=config["experts"]["text_dim"],
        disable_nan_checks=config["disable_nan_checks"],
    )
    logger.info(model)

    metrics = [getattr(module_metric, met) for met in config['metrics']]
    visualizer = config.init(
        name='visualizer',
        module=module_vis,
        exp_name=config._exper_name,
        log_dir=config._web_log_dir,
    )
    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing.  Note that some datasets fail to fit the retrieval
    # set on the GPU, so we run them on the CPU
    if torch.cuda.is_available() and not config.get("disable_gpu", False):
        device = "cuda"
    else:
        device = "cpu"
    logger.info(f"Running evaluation on {device}")

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        samples, meta = data_loaders["retrieval"]

        # To use the nan-checks safely, we need to make temporary copies of the data
        disable_nan_checks = config._config["disable_nan_checks"]
        with valid_samples(samples, device, disable_nan_checks) as valid:
            output = model(**valid)

        sims = output["cross_view_conf_matrix"].data.cpu().float().numpy()
        dataset = data_loaders.dataset_name
        nested_metrics = {}
        for metric in metrics:
            metric_name = metric.__name__
            res = metric(sims, query_masks=meta["query_masks"])
            verbose(epoch=0, metrics=res, name=dataset, mode=metric_name)
            nested_metrics[metric_name] = res

    if data_loaders.num_test_captions == 1:
        visualizer.visualize_ranking(
            sims=sims,
            meta=meta,
            epoch=0,
            nested_metrics=nested_metrics,
        )
    log = {}
    for subkey, subval in nested_metrics.items():
        for subsubkey, subsubval in subval.items():
            log[f"test_{subkey}_{subsubkey}"] = subsubval
    for key, value in log.items():
        logger.info(' {:15s}: {}'.format(str(key), value))