Example #1
def main(exp_const,data_const,model_const):
    io.mkdir_if_not_exists(exp_const.exp_dir,recursive=True)
    io.mkdir_if_not_exists(exp_const.log_dir)
    io.mkdir_if_not_exists(exp_const.model_dir)
    configure(exp_const.log_dir)
    save_constants({
        'exp': exp_const,
        'data': data_const,
        'model': model_const},
        exp_const.exp_dir)

    print('Creating network ...')
    model = Model()
    model.const = model_const
    model.encoder = Encoder(model.const.encoder).cuda()
    model.decoder = Decoder(model.const.decoder).cuda()

    encoder_path = os.path.join(
        exp_const.model_dir,
        f'encoder_{-1}')
    torch.save(model.encoder.state_dict(),encoder_path)

    decoder_path = os.path.join(
        exp_const.model_dir,
        f'decoder_{-1}')
    torch.save(model.decoder.state_dict(),decoder_path)

    print('Creating dataloader ...')
    dataset = VisualFeaturesDataset(data_const)
    dataloader = DataLoader(
        dataset,
        batch_size=exp_const.batch_size,
        shuffle=True)

    train_model(model,dataloader,exp_const)
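The initial checkpoints written above (saved as 'encoder_-1' and 'decoder_-1') can later be restored with load_state_dict; a minimal sketch, assuming the same Encoder/Decoder classes and the exp_const/model_const objects used above are in scope:

import os

import torch

# Rebuild the modules and load the initial checkpoints saved above.
encoder = Encoder(model_const.encoder).cuda()
encoder.load_state_dict(
    torch.load(os.path.join(exp_const.model_dir, 'encoder_-1')))

decoder = Decoder(model_const.decoder).cuda()
decoder.load_state_dict(
    torch.load(os.path.join(exp_const.model_dir, 'decoder_-1')))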
Example #2
def main(exp_const, data_const):
    print(f'Creating directory {exp_const.exp_dir} ...')
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)

    print('Saving constants ...')
    save_constants({'exp': exp_const, 'data': data_const}, exp_const.exp_dir)

    print('Loading data ...')
    img_id_to_obj_id = io.load_json_object(
        data_const.image_id_to_object_id_json)
    object_annos = io.load_json_object(data_const.object_annos_json)

    cooccur = {}
    for img_id, obj_ids in tqdm(img_id_to_obj_id.items()):
        synset_list = create_synset_list(object_annos, obj_ids)
        for synset1 in synset_list:
            for synset2 in synset_list:
                if synset1 not in cooccur:
                    cooccur[synset1] = {}

                if synset2 not in cooccur[synset1]:
                    cooccur[synset1][synset2] = 0

                cooccur[synset1][synset2] += 1

    synset_cooccur_json = os.path.join(exp_const.exp_dir,
                                       'synset_cooccur.json')
    io.dump_json_object(cooccur, synset_cooccur_json)
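The nested loop above counts every ordered synset pair within an image, including each synset with itself, so the dictionary also carries self-counts on its diagonal. A minimal sketch of the same counting pattern on toy data (the synset names are made up for illustration):

cooccur = {}
for synset_list in [['dog.n.01', 'ball.n.01'], ['dog.n.01']]:  # toy per-image synset lists
    for synset1 in synset_list:
        for synset2 in synset_list:
            cooccur.setdefault(synset1, {}).setdefault(synset2, 0)
            cooccur[synset1][synset2] += 1

# cooccur == {'dog.n.01': {'dog.n.01': 2, 'ball.n.01': 1},
#             'ball.n.01': {'dog.n.01': 1, 'ball.n.01': 1}}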
Example #3
def main(exp_const, data_const, model_const):
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)
    io.mkdir_if_not_exists(exp_const.log_dir)
    io.mkdir_if_not_exists(exp_const.model_dir)
    configure(exp_const.log_dir)
    save_constants({
        'exp': exp_const,
        'data': data_const,
        'model': model_const
    }, exp_const.exp_dir)

    print('Creating model ...')
    model = Model()
    model.const = model_const
    model.concat_svm = ConcatSVM(model_const.concat_svm).cuda()
    model.to_txt(exp_const.exp_dir, single_file=True)

    print('Creating train data loader ...')
    train_data_const = copy.deepcopy(data_const)
    train_data_const.subset = 'train'
    train_data_loader = DataLoader(SemEval201810Dataset(train_data_const),
                                   batch_size=exp_const.batch_size,
                                   shuffle=True)

    print('Creating val data loader ...')
    val_data_const = copy.deepcopy(data_const)
    val_data_const.subset = 'val'
    val_data_loader = DataLoader(SemEval201810Dataset(val_data_const),
                                 batch_size=exp_const.batch_size,
                                 shuffle=False)

    print('Begin training ...')
    train_model(model, train_data_loader, val_data_loader, exp_const)
Example #4
def main(exp_const,
         data_const_train,
         data_const_val,
         model_const,
         data_sign='hico'):
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)
    io.mkdir_if_not_exists(exp_const.log_dir)
    io.mkdir_if_not_exists(exp_const.model_dir)
    configure(exp_const.log_dir)
    save_constants(
        {
            'exp': exp_const,
            'data_train': data_const_train,
            'data_val': data_const_val,
            'model': model_const
        }, exp_const.exp_dir)

    print('Creating model ...')
    model = Model()
    model.const = model_const
    model.hoi_classifier = HoiClassifier(model.const.hoi_classifier,
                                         data_sign).cuda()
    model.to_txt(exp_const.exp_dir, single_file=True)

    print('Creating data loaders ...')
    dataset_train = Features(data_const_train)
    dataset_val = Features(data_const_val)

    train_model(model, dataset_train, dataset_val, exp_const)
Example #5
def main(exp_const, data_const, model_const):
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)
    io.mkdir_if_not_exists(exp_const.log_dir)
    io.mkdir_if_not_exists(exp_const.model_dir)
    io.mkdir_if_not_exists(exp_const.vis_dir)
    configure(exp_const.log_dir)
    save_constants({
        'exp': exp_const,
        'data': data_const,
        'model': model_const
    }, exp_const.exp_dir)

    print('Creating network ...')
    model = Model()
    model.const = model_const
    model.net = NET(model.const.net)
    if model.const.model_num is not None:
        model.net.load_state_dict(torch.load(model.const.net_path))
    model.net.cuda()
    model.img_mean = np.array([0.485, 0.456, 0.406])
    model.img_std = np.array([0.229, 0.224, 0.225])
    model.to_file(os.path.join(exp_const.exp_dir, 'model.txt'))

    print('Creating dataloader ...')
    dataloaders = {}
    for mode, subset in exp_const.subset.items():
        data_const = copy.deepcopy(data_const)
        data_const.subset = subset
        dataset = DATASET(data_const)
        dataloaders[mode] = DataLoader(dataset,
                                       batch_size=exp_const.batch_size,
                                       shuffle=True,
                                       num_workers=exp_const.num_workers)

    train_model(model, dataloaders, exp_const)
Example #6
def main(exp_const, data_const):
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)
    save_constants({'exp': exp_const, 'data': data_const}, exp_const.exp_dir)

    print('Loading glove embeddings ...')
    glove_idx = io.load_json_object(data_const.glove_idx)
    glove_h5py = h5py.File(data_const.glove_h5py, 'r')
    glove_embeddings = glove_h5py['embeddings'][()]
    num_glove_words, glove_dim = glove_embeddings.shape
    print('-' * 80)
    print(f'number of glove words: {num_glove_words}')
    print(f'glove dim: {glove_dim}')
    print('-' * 80)

    print('Loading visual features ...')
    visual_features_idx = io.load_json_object(data_const.visual_features_idx)
    visual_features_h5py = h5py.File(data_const.visual_features_h5py, 'r')
    visual_features = visual_features_h5py['features'][()]
    num_visual_features, visual_features_dim = visual_features.shape
    print('-' * 80)
    print(f'number of visual features: {num_visual_features}')
    print(f'visual feature dim: {visual_features_dim}')
    print('-' * 80)

    print('Combining glove with visual features ...')
    visual_word_vecs_idx_json = os.path.join(exp_const.exp_dir,
                                             'visual_word_vecs_idx.json')
    io.dump_json_object(glove_idx, visual_word_vecs_idx_json)
    visual_word_vecs_h5py = h5py.File(
        os.path.join(exp_const.exp_dir, 'visual_word_vecs.h5py'), 'w')
    visual_word_vec_dim = glove_dim + visual_features_dim
    visual_word_vecs = np.zeros([num_glove_words, visual_word_vec_dim])
    mean_visual_feature = visual_features_h5py['mean'][()]
    for word in tqdm(glove_idx.keys()):
        glove_id = glove_idx[word]
        glove_vec = glove_embeddings[glove_id]
        if word in visual_features_idx:
            feature_id = visual_features_idx[word]
            feature = visual_features[feature_id]
        else:
            feature = mean_visual_feature
        visual_word_vec = np.concatenate(
            (glove_vec, (feature - mean_visual_feature)))
        # visual_word_vec = np.concatenate((
        #     normalize(glove_vec),
        #     normalize(feature)))
        visual_word_vecs[glove_id] = visual_word_vec

    visual_word_vecs_h5py.create_dataset('embeddings',
                                         data=visual_word_vecs,
                                         chunks=(1, visual_word_vec_dim))
    visual_word_vecs_h5py.close()
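The combined vectors written above can be read back with plain json and h5py; a minimal sketch, assuming the output paths produced by the script (the word 'dog' is only an illustrative key):

import json
import os

import h5py

exp_dir = '/path/to/exp_dir'  # assumed to match exp_const.exp_dir above

with open(os.path.join(exp_dir, 'visual_word_vecs_idx.json')) as f:
    word_to_idx = json.load(f)  # word -> row index in the embeddings matrix

with h5py.File(os.path.join(exp_dir, 'visual_word_vecs.h5py'), 'r') as f:
    vec = f['embeddings'][word_to_idx['dog']]  # glove_dim + visual_features_dim values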
Example #7
def main(exp_const, data_const, model_const):
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)
    io.mkdir_if_not_exists(exp_const.log_dir)
    io.mkdir_if_not_exists(exp_const.model_dir)
    io.mkdir_if_not_exists(exp_const.vis_dir)
    configure(exp_const.log_dir)
    if model_const.model_num is None:
        const_dict = {
            'exp': exp_const,
            'data': data_const,
            'model': model_const
        }
    else:
        const_dict = {
            f'exp_finetune_{model_const.model_num}': exp_const,
            f'data_finetune_{model_const.model_num}': data_const,
            f'model_finetune_{model_const.model_num}': model_const
        }
    save_constants(const_dict, exp_const.exp_dir)

    print('Creating network ...')
    model = Model()
    model.const = model_const
    model.net = LogBilinear(model.const.net)
    if model.const.model_num is not None:
        model.net.load_state_dict(torch.load(model.const.net_path))
    model.net.cuda()
    model.to_file(os.path.join(exp_const.exp_dir, 'model.txt'))

    print('Creating positive dataloader ...')
    dataset = MultiSenseCooccurDataset(data_const)
    collate_fn = dataset.create_collate_fn()
    dataloader = DataLoader(dataset,
                            batch_size=exp_const.batch_size,
                            shuffle=True,
                            num_workers=exp_const.num_workers,
                            collate_fn=collate_fn)

    print('Creating negative dataloader ...')
    neg_dataset = NegMultiSenseCooccurDataset(data_const)
    collate_fn = neg_dataset.create_collate_fn()
    neg_dataloader = DataLoader(neg_dataset,
                                batch_size=exp_const.batch_size,
                                shuffle=True,
                                num_workers=exp_const.num_workers,
                                collate_fn=collate_fn)

    err_msg = f'Num words mismatch (try {len(dataset.words)})'
    assert (len(dataset.words) == model.const.net.num_words), err_msg

    train_model(model, dataloader, neg_dataloader, exp_const)
Example #8
def generate(exp_const, data_const, data_sign):
    print(f'Creating exp_dir: {exp_const.exp_dir}')
    io.mkdir_if_not_exists(exp_const.exp_dir)

    save_constants({'exp': exp_const, 'data': data_const}, exp_const.exp_dir)

    print(f'Reading split_ids.json ...')
    split_ids = io.load_json_object(data_const.split_ids_json)

    print('Creating an object-detector-only HOI detector ...')
    hoi_cand_gen = HoiCandidatesGenerator(data_const, data_sign)

    print(f'Creating a hoi_candidates_{exp_const.subset}.hdf5 file ...')
    hoi_cand_hdf5 = os.path.join(exp_const.exp_dir,
                                 f'hoi_candidates_{exp_const.subset}.hdf5')
    f = h5py.File(hoi_cand_hdf5, 'w')

    # High-scoring detections selected from all Faster R-CNN predictions
    print('Reading selected dets from hdf5 file ...')
    all_selected_dets = h5py.File(data_const.selected_dets_hdf5, 'r')

    for global_id in tqdm(split_ids[exp_const.subset]):
        selected_dets = {
            'boxes': {},
            'scores': {},
            'rpn_ids': {},
            'obj_cls': {}
        }
        start_end_ids = all_selected_dets[global_id]['start_end_ids'][()]
        boxes_scores_rpn_ids = \
            all_selected_dets[global_id]['boxes_scores_rpn_ids'][()]

        for cls_ind, cls_name in enumerate(COCO_CLASSES):
            start_id, end_id = start_end_ids[cls_ind]
            boxes = boxes_scores_rpn_ids[start_id:end_id, :4]
            scores = boxes_scores_rpn_ids[start_id:end_id, 4]
            rpn_ids = boxes_scores_rpn_ids[start_id:end_id, 5]
            object_cls = np.full((end_id - start_id, ), cls_ind)
            selected_dets['boxes'][cls_name] = boxes
            selected_dets['scores'][cls_name] = scores
            selected_dets['rpn_ids'][cls_name] = rpn_ids
            selected_dets['obj_cls'][cls_name] = object_cls

        pred_dets, start_end_ids = hoi_cand_gen.predict(selected_dets)
        f.create_group(global_id)
        f[global_id].create_dataset('boxes_scores_rpn_ids_hoi_idx',
                                    data=pred_dets)
        f[global_id].create_dataset('start_end_ids', data=start_end_ids)

    f.close()
Example #9
def assign(exp_const, data_const):
    io.mkdir_if_not_exists(exp_const.exp_dir)

    print('Saving constants ...')
    save_constants({'exp': exp_const, 'data':data_const}, exp_const.exp_dir)

    print(f'Reading hoi_candidates_{exp_const.subset}.hdf5 ...')
    hoi_cand_hdf5 = h5py.File(data_const.hoi_cand_hdf5, 'r')

    print(f'Creating hoi_candidate_labels_{exp_const.subset}.hdf5 ...')
    filename = os.path.join(
        exp_const.exp_dir,
        f'hoi_candidate_labels_{exp_const.subset}.hdf5')
    hoi_cand_label_hdf5 = h5py.File(filename, 'w')

    print('Loading gt hoi detections ...')
    split_ids = io.load_json_object(data_const.split_ids_json)
    global_ids = split_ids[exp_const.subset]
    gt_dets = load_gt_dets(data_const.anno_list_json, global_ids)

    print('Loading hoi_list.json ...')
    hoi_list = io.load_json_object(data_const.hoi_list_json)
    hoi_ids = [hoi['id'] for hoi in hoi_list]

    for global_id in tqdm(global_ids):
        boxes_scores_rpn_ids_hoi_idx = \
            hoi_cand_hdf5[global_id]['boxes_scores_rpn_ids_hoi_idx']
        start_end_ids = hoi_cand_hdf5[global_id]['start_end_ids']
        num_cand = boxes_scores_rpn_ids_hoi_idx.shape[0]
        labels = np.zeros([num_cand])
        for hoi_id in gt_dets[global_id]:
            start_id, end_id = start_end_ids[int(hoi_id)-1]
            for i in range(start_id, end_id):
                cand_det = {
                    'human_box': boxes_scores_rpn_ids_hoi_idx[i, :4],
                    'object_box': boxes_scores_rpn_ids_hoi_idx[i, 4:8],
                }
                # Check whether the candidate human-object pair matches any
                # ground-truth pair; if it does (IoU > 0.5), set the label to 1
                is_match = match_hoi(cand_det, gt_dets[global_id][hoi_id])
                if is_match:
                    labels[i] = 1.0

        hoi_cand_label_hdf5.create_dataset(global_id, data=labels)

    hoi_cand_label_hdf5.close()
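match_hoi itself is not shown here; a minimal sketch of the IoU > 0.5 matching described in the comment above, assuming [x1, y1, x2, y2] boxes and ground-truth entries that are dicts with 'human_box' and 'object_box' keys:

def box_iou(box1, box2):
    # Intersection over union of two [x1, y1, x2, y2] boxes.
    x1, y1 = max(box1[0], box2[0]), max(box1[1], box2[1])
    x2, y2 = min(box1[2], box2[2]), min(box1[3], box2[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    return inter / (area1 + area2 - inter + 1e-6)


def match_hoi(cand_det, gt_dets):
    # A candidate is a positive if both its human and object boxes
    # overlap some ground-truth pair with IoU > 0.5.
    for gt_det in gt_dets:
        if (box_iou(cand_det['human_box'], gt_det['human_box']) > 0.5 and
                box_iou(cand_det['object_box'], gt_det['object_box']) > 0.5):
            return True
    return False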
Example #10
def main(exp_const, data_const):
    print(f'Creating directory {exp_const.exp_dir} ...')
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)

    print('Saving constants ...')
    save_constants({'exp': exp_const, 'data': data_const}, exp_const.exp_dir)

    print('Creating dataloader ...')
    data_const = copy.deepcopy(data_const)
    dataset = ImagenetNoImgsDataset(data_const)
    collate_fn = dataset.create_collate_fn()
    dataloader = DataLoader(dataset,
                            batch_size=exp_const.batch_size,
                            shuffle=False,
                            num_workers=exp_const.num_workers,
                            collate_fn=collate_fn)

    create_gt_synset_cooccur(exp_const, dataloader)
Example #11
def main(exp_const, data_const, model_const):
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)
    io.mkdir_if_not_exists(exp_const.log_dir)
    io.mkdir_if_not_exists(exp_const.model_dir)
    io.mkdir_if_not_exists(exp_const.vis_dir)
    configure(exp_const.log_dir)
    save_constants({
        'exp': exp_const,
        'data': data_const,
        'model': model_const
    }, exp_const.exp_dir)

    print('Creating network ...')
    model = Model()
    model.const = model_const
    model.net = ResnetModel(model.const.net)
    model.embed2class = Embed2Class(model.const.embed2class)
    if model.const.model_num is not None:
        model.net.load_state_dict(torch.load(model.const.net_path))
        model.embed2class.load_state_dict(
            torch.load(model.const.embed2class_path))
    model.net.cuda()
    model.embed2class.cuda()
    model.img_mean = np.array([0.485, 0.456, 0.406])
    model.img_std = np.array([0.229, 0.224, 0.225])
    model.to_file(os.path.join(exp_const.exp_dir, 'model.txt'))

    print('Creating dataloader ...')
    dataloaders = {}
    for mode, subset in exp_const.subset.items():
        data_const = copy.deepcopy(data_const)
        if subset == 'train':
            data_const.train = True
        else:
            data_const.train = False
        dataset = Cifar100Dataset(data_const)
        collate_fn = dataset.get_collate_fn()
        dataloaders[mode] = DataLoader(dataset,
                                       batch_size=exp_const.batch_size,
                                       shuffle=True,
                                       num_workers=exp_const.num_workers,
                                       collate_fn=collate_fn)

    train_model(model, dataloaders, exp_const)
Example #12
def main(exp_const, data_const):
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)
    save_constants({'exp': exp_const, 'data': data_const}, exp_const.exp_dir)

    print('Loading glove embeddings ...')
    glove_idx = io.load_json_object(data_const.glove_idx)
    glove_h5py = h5py.File(data_const.glove_h5py, 'r')
    glove_embeddings = glove_h5py['embeddings'][()]
    num_glove_words, glove_dim = glove_embeddings.shape
    print('-' * 80)
    print(f'number of glove words: {num_glove_words}')
    print(f'glove dim: {glove_dim}')
    print('-' * 80)

    random = 2 * (np.random.rand(num_glove_words, exp_const.random_dim) - 0.5)
    word_vecs = np.concatenate((glove_embeddings, random), 1)
    word_vec_dim = glove_dim + exp_const.random_dim
    word_vecs_h5py = h5py.File(
        os.path.join(exp_const.exp_dir, 'glove_random_word_vecs.h5py'), 'w')
    word_vecs_h5py.create_dataset('embeddings',
                                  data=word_vecs,
                                  chunks=(1, word_vec_dim))
    word_vecs_h5py.close()
Example #13
def main(exp_const, data_const):
    nltk.download('averaged_perceptron_tagger')

    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)
    save_constants({'exp': exp_const, 'data': data_const}, exp_const.exp_dir)

    print('Loading glove embeddings ...')
    glove_idx = io.load_json_object(data_const.glove_idx)
    glove_h5py = h5py.File(data_const.glove_h5py, 'r')
    glove_embeddings = glove_h5py['embeddings'][()]
    num_glove_words, glove_dim = glove_embeddings.shape
    print('-' * 80)
    print(f'number of glove words: {num_glove_words}')
    print(f'glove dim: {glove_dim}')
    print('-' * 80)

    print('Loading visual embeddings ...')
    visual_embeddings = np.load(data_const.visual_embeddings_npy)
    visual_word_to_idx = io.load_json_object(data_const.visual_word_to_idx)
    # io.dump_json_object(
    #     list(visual_word_to_idx.keys()),
    #     os.path.join(exp_const.exp_dir,'visual_words.json'))
    mean_visual_embedding = np.mean(visual_embeddings, 0)
    num_visual_words, visual_embed_dim = visual_embeddings.shape
    print('-' * 80)
    print(f'number of visual embeddings: {num_visual_words}')
    print(f'visual embedding dim: {visual_embed_dim}')
    print('-' * 80)

    print('Combining glove with visual embeddings ...')
    visual_word_vecs_idx_json = os.path.join(exp_const.exp_dir,
                                             'visual_word_vecs_idx.json')
    io.dump_json_object(glove_idx, visual_word_vecs_idx_json)
    visual_word_vecs_h5py = h5py.File(
        os.path.join(exp_const.exp_dir, 'visual_word_vecs.h5py'), 'w')
    visual_word_vec_dim = glove_dim + visual_embed_dim
    visual_word_vecs = np.zeros([num_glove_words, visual_word_vec_dim])
    visual_words = set()
    lemmatizer = Lemmatizer()
    for word in tqdm(glove_idx.keys()):
        glove_id = glove_idx[word]
        glove_vec = glove_embeddings[glove_id]

        if word in visual_word_to_idx:
            idx = visual_word_to_idx[word]
            visual_embedding = visual_embeddings[idx]
            visual_words.add(word)
        else:
            lemma = lemmatizer.lemmatize(word)
            if lemma in visual_word_to_idx:
                idx = visual_word_to_idx[lemma]
                visual_embedding = visual_embeddings[idx]
                visual_words.add(lemma)
                visual_words.add(word)
            else:
                visual_embedding = mean_visual_embedding

        visual_word_vec = np.concatenate((glove_vec, visual_embedding))
        visual_word_vecs[glove_id] = visual_word_vec

    visual_word_vecs_h5py.create_dataset('embeddings',
                                         data=visual_word_vecs,
                                         chunks=(1, visual_word_vec_dim))
    visual_word_vecs_h5py.close()

    io.dump_json_object(list(visual_words),
                        os.path.join(exp_const.exp_dir, 'visual_words.json'))
Example #14
def main(exp_const, data_const):
    print(f'Creating directory {exp_const.exp_dir} ...')
    io.mkdir_if_not_exists(exp_const.exp_dir, recursive=True)

    print('Saving constants ...')
    save_constants({'exp': exp_const, 'data': data_const}, exp_const.exp_dir)

    print('Loading cooccur ...')
    num_cooccur_types = len(data_const.cooccur_paths)
    merged_cooccur = {}
    for i, (cooccur_type,cooccur_json) in \
        enumerate(data_const.cooccur_paths.items()):

        print(f'    Merging {cooccur_type} ...')
        cooccur = io.load_json_object(cooccur_json)
        for word1, context in tqdm(cooccur.items()):
            if word1 not in merged_cooccur:
                merged_cooccur[word1] = {}

            for word2, count in context.items():
                if word2 not in merged_cooccur[word1]:
                    merged_cooccur[word1][word2] = [0] * num_cooccur_types

                merged_cooccur[word1][word2][i] += count

    if exp_const.normalize == True:
        print('Normalizing by self counts ...')
        for word1, context in tqdm(merged_cooccur.items()):
            norm_counts = merged_cooccur[word1][word1]

            for word2, counts in context.items():
                if word2 == word1:
                    continue

                for i, norm_count in enumerate(norm_counts):
                    counts[i] = counts[i] / (norm_count + 1e-6)
                    counts[i] = min(1, counts[i])

            merged_cooccur[word1][word1] = [1] * len(norm_counts)

    pandas_cols = {
        'word1': [],
        'word2': [],
    }
    for cooccur_type in data_const.cooccur_paths.keys():
        pandas_cols[cooccur_type] = []

    print('Creating pandas columns ...')
    for word1, context in tqdm(merged_cooccur.items()):
        for word2, counts in context.items():
            pandas_cols['word1'].append(word1)
            pandas_cols['word2'].append(word2)
            for i, cooccur_type in enumerate(data_const.cooccur_paths.keys()):
                pandas_cols[cooccur_type].append(counts[i])

    pandas_cols['word1'] = pd.Categorical(pandas_cols['word1'])
    pandas_cols['word2'] = pd.Categorical(pandas_cols['word2'])

    for cooccur_type in data_const.cooccur_paths.keys():
        pandas_cols[cooccur_type] = pd.Series(pandas_cols[cooccur_type])

    df = pd.DataFrame(pandas_cols)

    print('Saving DataFrame to csv ...')
    df.to_csv(data_const.merged_cooccur_csv, index=False)
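The resulting CSV has one row per (word1, word2) pair and one count column per co-occurrence type; a minimal sketch of reading it back, where 'synset_cooccur' stands in for whatever keys data_const.cooccur_paths used:

import pandas as pd

# Path assumed to match data_const.merged_cooccur_csv above.
df = pd.read_csv('/path/to/merged_cooccur.csv')

# Contexts of a single word, ranked by one co-occurrence type.
dog_rows = df[df['word1'] == 'dog'].sort_values('synset_cooccur', ascending=False)
print(dog_rows[['word2', 'synset_cooccur']].head())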
Example #15
def main(exp_const,data_const,model_const):
    np.random.seed(exp_const.seed)
    torch.manual_seed(exp_const.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    io.mkdir_if_not_exists(exp_const.exp_dir,recursive=True)
    io.mkdir_if_not_exists(exp_const.log_dir)
    io.mkdir_if_not_exists(exp_const.model_dir)
    io.mkdir_if_not_exists(exp_const.vis_dir)
    
    tb_writer = SummaryWriter(log_dir=exp_const.log_dir)
    
    model_num = model_const.model_num
    save_constants({
        f'exp_{model_num}': exp_const,
        f'data_train_{model_num}': data_const['train'],
        f'data_val_{model_num}': data_const['val'],
        f'model_{model_num}': model_const},
        exp_const.exp_dir)
    
    print('Creating network ...')
    model = Constants()
    model.const = model_const
    model.object_encoder = ObjectEncoder(model.const.object_encoder)
    model.cap_encoder = CapEncoder(model.const.cap_encoder)
    if exp_const.random_lang is True:
        model.cap_encoder.random_init()

    c_dim = model.object_encoder.const.object_feature_dim
    if exp_const.contextualize==True:
        c_dim = model.object_encoder.const.context_layer.hidden_size
    model.self_sup_criterion = create_info_nce_criterion(
        model.object_encoder.const.object_feature_dim,
        c_dim,
        model.object_encoder.const.context_layer.hidden_size)
    
    o_dim = model.object_encoder.const.object_feature_dim
    if exp_const.contextualize==True:
        o_dim = model.object_encoder.const.context_layer.hidden_size
    
    model.lang_sup_criterion = create_cap_info_nce_criterion(
        o_dim,
        model.object_encoder.const.object_feature_dim,
        model.cap_encoder.model.config.hidden_size,
        model.cap_encoder.model.config.hidden_size//2,
        model.const.cap_info_nce_layers)
    if model.const.model_num != -1:
        model.object_encoder.load_state_dict(
            torch.load(model.const.object_encoder_path)['state_dict'])
        model.self_sup_criterion.load_state_dict(
            torch.load(model.const.self_sup_criterion_path)['state_dict'])
        model.lang_sup_criterion.load_state_dict(
            torch.load(model.const.lang_sup_criterion_path)['state_dict'])
    model.object_encoder.cuda()
    model.cap_encoder.cuda()
    model.self_sup_criterion.cuda()
    model.lang_sup_criterion.cuda()
    model.object_encoder.to_file(
        os.path.join(exp_const.exp_dir,'object_encoder.txt'))
    model.self_sup_criterion.to_file(
        os.path.join(exp_const.exp_dir,'self_supervised_criterion.txt'))
    model.lang_sup_criterion.to_file(
        os.path.join(exp_const.exp_dir,'lang_supervised_criterion.txt'))

    print('Creating dataloader ...')
    dataloaders = {}
    if exp_const.dataset=='coco':
        Dataset = CocoDataset
    elif exp_const.dataset=='flickr':
        Dataset = FlickrDataset
    else:
        msg = f'{exp_const.dataset} not implemented'
        raise NotImplementedError(msg)

    for mode, const in data_const.items():
        dataset = Dataset(const)
        
        if mode=='train':
            shuffle=True
            batch_size=exp_const.train_batch_size
        else:
            shuffle=True
            batch_size=exp_const.val_batch_size

        dataloaders[mode] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=exp_const.num_workers)

    train_model(model,dataloaders,exp_const,tb_writer)