Exemple #1
0
    def __init__(self,
                 split='train',
                 lang_input='phrase',
                 neg_img=True,
                 neg_lang=True):
        """Collect every positive (image index, language key) pair of a split.

        Args:
            split: key into ``self.img_splits`` selecting the images to use.
            lang_input: ``'phrase'`` pairs each image with each of its
                ``phrase_ids``; ``'description'`` pairs each image with the
                index of each of its descriptions. Any other value raises
                ``NotImplementedError`` (only once an image is visited).
            neg_img: stored flag; presumably toggles negative-image sampling
                in other methods of this class (unused here).
            neg_lang: stored flag; presumably toggles negative-language
                sampling in other methods of this class (unused here).
        """
        data.Dataset.__init__(self)
        TextureDescriptionData.__init__(self, phid_format='str')
        self.split = split
        self.lang_input = lang_input
        self.neg_img = neg_img
        self.neg_lang = neg_lang
        self.img_transform = build_transforms(is_train=False)

        self.pos_pairs = list()
        for img_idx, img_name in enumerate(self.img_splits[self.split]):
            entry = self.img_data_dict[img_name]
            if self.lang_input == 'phrase':
                lang_keys = entry['phrase_ids']
            elif self.lang_input == 'description':
                lang_keys = range(len(entry['descriptions']))
            else:
                raise NotImplementedError
            self.pos_pairs.extend((img_idx, key) for key in lang_keys)
        return
    def __init__(self, model=None, img_transform=None, device='cuda',
                 trained_path='output/triplet_match/c34_bert_l2_s_lr0.00001', model_file='BEST_checkpoint.pth',
                 split_to_phrases=False, dataset=None):
        """Wrap a (possibly freshly loaded) model for image/phrase matching.

        Args:
            model: model to use. When None, it is loaded with
                ``load_model(trained_path, model_file)`` and switched to eval
                mode, and the device returned by ``load_model`` replaces
                ``device``.
            img_transform: image preprocessing; defaults to
                ``build_transforms(is_train=False)``.
            device: device associated with a caller-supplied ``model``.
            trained_path, model_file: checkpoint location used only when
                ``model`` is None.
            split_to_phrases: when True, precompute ``self.ph_vec_dict``, a
                map from each phrase in ``dataset.phrases`` to its vector from
                ``get_phrase_vecs``.
            dataset: texture data; defaults to
                ``TextureDescriptionData(phid_format=None)``.
        """
        if model is None:
            model, device = load_model(trained_path, model_file)
            model.eval()
        self.model, self.device = model, device

        self.img_transform = (build_transforms(is_train=False)
                              if img_transform is None else img_transform)

        self.split_to_phrases = split_to_phrases
        self.dataset = (TextureDescriptionData(phid_format=None)
                        if dataset is None else dataset)

        self.ph_vec_dict = None
        if self.split_to_phrases:
            vecs = get_phrase_vecs(self.model, self.dataset)
            self.ph_vec_dict = {phrase: vecs[idx]
                                for idx, phrase in enumerate(self.dataset.phrases)}

        # Starts empty; presumably populated by other methods of this class.
        self.img_vecs = {}
        return
    def __init__(self, split='train', is_train=True, cached_resnet_feats=None):
        """Initialise the phrase-classification dataset.

        Args:
            split: name of the data split this dataset represents.
            is_train: passed to ``build_transforms`` when images have to be
                transformed on the fly.
            cached_resnet_feats: optional precomputed features. A non-None,
                non-empty value turns on ``use_cache`` and leaves
                ``self.transform`` as None.
        """
        data.Dataset.__init__(self)
        TextureDescriptionData.__init__(self, phid_format='set')

        self.split = split
        self.is_train = is_train
        self.cached_resnet_feats = cached_resnet_feats
        have_cache = cached_resnet_feats is not None and len(cached_resnet_feats) > 0
        self.use_cache = have_cache
        # A transform is only needed when raw images (not cached features) are read.
        self.transform = None if have_cache else build_transforms(is_train)
        print('PhraseClassifyDataset initialized.')
    def __init__(self, model=None, img_transform=None, dataset=None, device='cuda'):
        """Set up a predictor backed by precomputed phrase vectors.

        Args:
            model: model to use. When None, it is loaded with ``load_model()``
                and switched to eval mode, and the device returned by
                ``load_model`` replaces ``device``.
            img_transform: image preprocessing; defaults to
                ``build_transforms(is_train=False)``.
            dataset: texture data; defaults to
                ``TextureDescriptionData(phid_format=None)``.
            device: device associated with a caller-supplied ``model``.

        ``self.ph_vecs`` holds ``get_phrase_vecs(self.model, self.dataset)``.
        """
        if model is None:
            model, device = load_model()
            # The other constructors put a freshly loaded model in eval mode
            # before using it; do the same so the phrase vectors below are not
            # computed in training mode (harmless if load_model already did).
            model.eval()
        self.model = model
        self.device = device

        if img_transform is None:
            img_transform = build_transforms(is_train=False)
        self.img_transform = img_transform

        if dataset is None:
            dataset = TextureDescriptionData(phid_format=None)
        self.dataset = dataset

        self.ph_vecs = get_phrase_vecs(self.model, self.dataset)
        return
    def __init__(self,
                 model=None,
                 img_transform=None,
                 device='cuda',
                 dataset=None):
        """Set up an image/phrase scorer.

        Args:
            model: model to use. When None, it is loaded with
                ``load_model(dataset=<resolved dataset>)`` and switched to eval
                mode, and the device returned by ``load_model`` replaces
                ``device``.
            img_transform: image preprocessing; defaults to
                ``build_transforms(is_train=False)``.
            device: device associated with a caller-supplied ``model``.
            dataset: texture data; defaults to
                ``TextureDescriptionData(phid_format=None)``. Resolved first
                because ``load_model`` receives it.
        """
        self.dataset = (dataset if dataset is not None
                        else TextureDescriptionData(phid_format=None))

        if model is not None:
            self.model, self.device = model, device
        else:
            loaded_model, loaded_device = load_model(dataset=self.dataset)
            loaded_model.eval()
            self.model, self.device = loaded_model, loaded_device

        self.img_transform = (img_transform if img_transform is not None
                              else build_transforms(is_train=False))

        # Starts empty; presumably populated by other methods of this class.
        self.img_ph_scores = {}
        return
    def __init__(self):
        """Define the fixed reference images, color palette and pattern names.

        ``img_names``, ``is_fore_back`` and ``patterns`` are parallel
        10-element lists (e.g. index 0 is the cobwebbed image / 'web').
        ``color_tuples`` lists every ordered pair of distinct color indices.
        """
        # One reference image per entry of ``patterns`` below.
        self.img_names = [
            'cobwebbed/cobwebbed_0088.jpg', 'lined/lined_0084.jpg',
            'polka-dotted/polka-dotted_0215.jpg', 'swirly/swirly_0135.jpg',
            'chequered/chequered_0103.jpg', 'honeycombed/honeycombed_0059.jpg',
            'dotted/dotted_0131.jpg', 'striped/striped_0035.jpg',
            'zigzagged/zigzagged_0064.jpg', 'banded/banded_0138.jpg'
        ]

        # Color name -> 0-255 triple (presumably RGB).
        self.colors = {
            'white': (245, 245, 230),
            'black': (20, 20, 20),
            'brown': (120, 70, 20),
            'green': (30, 160, 30),
            'blue': (30, 100, 200),
            'red': (120, 30, 30),
            'yellow': (240, 240, 30),
            'pink': (240, 140, 190),
            'orange': (240, 140, 20),
            'gray': (100, 100, 100),
            'purple': (140, 30, 200),
            # 'silver': (192, 192, 192),  # disabled
        }
        self.color_names = list(self.colors)
        # Per-pattern flag; presumably whether the pattern has a distinct
        # foreground/background color pair (TODO confirm).
        self.is_fore_back = [
            True, True, True, True, False, False, True, False, False, False
        ]
        self.patterns = [
            'web', 'lines', 'polka-dots', 'swirls', 'squares', 'hexagon',
            'dots', 'stripes', 'zigzagged', 'banded'
        ]

        num_colors = len(self.color_names)
        self.color_tuples = [(first, second)
                             for first in range(num_colors)
                             for second in range(num_colors)
                             if first != second]

        self.img_transform = build_transforms(is_train=False)
def main():
    """
    Training and validation.

    Builds the show-attend-tell encoder/decoder from scratch when the global
    ``checkpoint`` is None, otherwise restores them (and their optimizers and
    training counters) from that checkpoint file. Then it alternates one
    training and one validation epoch from ``start_epoch`` to ``epochs``.
    Learning rates are scaled by 0.8 after every 8 epochs without BLEU-4
    improvement, training stops after 20 such epochs, and a checkpoint is
    saved after each epoch.

    Rebinds module-level state via ``global``. Note that ``checkpoint`` is
    rebound from a path to the loaded checkpoint object.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, word_encoder
    word_encoder = WordEncoder()

    # Initialize / load checkpoint
    if checkpoint is None:
        print("Starting training from scratch!")
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_encoder.word_list),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, decoder.parameters()),
            lr=decoder_lr)
        encoder: Encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        # The encoder only gets an optimizer when it is being fine-tuned.
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None
        if init_word_emb == 'fast_text':
            # Initialise the decoder word embeddings from fastText vectors.
            word_emb = get_word_embed(word_encoder.word_list, 'fast_text')
            decoder.embedding.weight.data.copy_(torch.from_numpy(word_emb))

    else:
        print("Loading from checkpoint " + checkpoint)
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # Resuming with fine-tuning enabled from a run that had it disabled:
        # unfreeze the encoder and create its missing optimizer.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(split='train',
                       is_train=True,
                       word_encoder=word_encoder,
                       transform=build_transforms(is_train=fine_tune_encoder)),
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(split='val',
                       is_train=False,
                       word_encoder=word_encoder,
                       transform=build_transforms(is_train=False)),
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
        collate_fn=caption_collate)

    # Single definition of the checkpoint directory, used both to create it
    # and as the destination passed to save_checkpoint.
    save_dir = 'output/show_attend_tell/checkpoints'

    print('ready to train.')
    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        os.makedirs(save_dir, exist_ok=True)
        save_checkpoint(checkpoint_base_name,
                        epoch,
                        epochs_since_improvement,
                        encoder,
                        decoder,
                        encoder_optimizer,
                        decoder_optimizer,
                        recent_bleu4,
                        is_best,
                        save_dir=save_dir)
    def __init__(self,
                 split=None,
                 data_path='other_datasets/CUB_200_2011',
                 val_ratio=0.1):
        """Load CUB-200-2011 metadata: splits, names, boxes and attribute labels.

        Args:
            split: 'train', 'val' or 'test'. Any other value is stored as None.
            data_path: root directory of the CUB_200_2011 release.
            val_ratio: fraction of the official train images moved into a
                'val' split, sampled with seed 0. When <= 0, 'val' is empty.

        All ids read from the files are 1-based and are stored 0-based.
        """
        self.data_path = data_path
        self.split = split
        self.val_ratio = val_ratio
        if split not in ['train', 'val', 'test']:
            self.split = None
        self.img_transform = build_transforms(is_train=False)

        def read_lines(rel_path):
            """Yield the non-blank lines of ``data_path/rel_path``."""
            with open(os.path.join(data_path, rel_path), 'r') as f:
                for line in f:
                    if line.strip():
                        yield line

        def read_indexed_names(rel_path):
            """Parse '<1-based id> <name>' lines whose ids must run 1, 2, 3, ..."""
            names = []
            for li, line in enumerate(read_lines(rel_path)):
                # maxsplit=1 keeps everything after the id as the name.
                idx, name = line.split(maxsplit=1)
                assert int(idx) - 1 == li
                names.append(name.strip())
            return names

        self.img_splits = {'train': list(), 'test': list()}
        for line in read_lines('train_test_split.txt'):
            img_id, is_train = line.split()
            key = 'train' if int(is_train) else 'test'
            self.img_splits[key].append(int(img_id) - 1)
        if val_ratio > 0:
            val_len = int(len(self.img_splits['train']) * val_ratio)
            # NOTE(review): this reseeds NumPy's *global* RNG as a side effect
            # of constructing the dataset; kept for backward compatibility of
            # downstream randomness.
            np.random.seed(0)
            self.img_splits['val'] = np.random.choice(self.img_splits['train'],
                                                      val_len,
                                                      replace=False)
            # Set membership keeps the filter linear instead of scanning the
            # val array for every train id.
            val_ids = set(self.img_splits['val'].tolist())
            self.img_splits['train'] = [
                i for i in self.img_splits['train'] if i not in val_ids
            ]
        else:
            self.img_splits['val'] = []

        self.class_names = read_indexed_names('classes.txt')
        self.att_names = read_indexed_names('attributes/attributes.txt')

        self.att_types = dict()
        self.att_types['shape'] = [
            i for i, n in enumerate(self.att_names)
            if '_shape:' in n or 'length:' in n or 'size:' in n
        ]
        for att_type in ['color', 'pattern']:
            marker = '_%s:' % att_type
            self.att_types[att_type] = [
                i for i, n in enumerate(self.att_names) if marker in n
            ]

        self.img_data_list = [{'img_name': name}
                              for name in read_indexed_names('images.txt')]

        for line in read_lines('image_class_labels.txt'):
            img_id, cls = line.split()
            self.img_data_list[int(img_id) - 1]['class_label'] = int(cls) - 1

        # <image_id> <x> <y> <width> <height>
        for line in read_lines('bounding_boxes.txt'):
            numbers = [float(n) for n in line.split()]
            self.img_data_list[int(numbers[0]) - 1]['box'] = numbers[1:]

        for img_data in self.img_data_list:
            img_data['att_labels'] = np.zeros(len(self.att_names))

        self.gt_att_labels = np.zeros(
            (len(self.img_data_list), len(self.att_names)))
        # <image_id> <attribute_id> <is_present> <certainty_id> <time>
        for line in read_lines('attributes/image_attribute_labels.txt'):
            img_id, att_id, is_present = [int(n) for n in line.split()[:3]]
            if is_present:
                self.img_data_list[img_id - 1]['att_labels'][att_id - 1] = 1
                self.gt_att_labels[img_id - 1, att_id - 1] = 1

        # TODO non-localized

        # One line per class, one space-separated column per attribute
        # (200 x 312 for the official release).
        self.class_att_labels = np.zeros(
            (len(self.class_names), len(self.att_names)))
        for li, line in enumerate(
                read_lines('attributes/class_attribute_labels_continuous.txt')):
            self.class_att_labels[li] = np.array(
                [float(n) for n in line.split()])

        print('CUB dataset ready.')
        return
Exemple #9
0
def predict(beam_size, img_dataset=None, split=split):
    """Caption every image with beam search.

    Args:
        beam_size: beam width ``k``. It is also the number of captions
            returned per image.
        img_dataset: dataset yielding ``(img_name, image)``. Defaults to an
            ``ImgOnlyDataset`` over ``split`` with eval transforms.
        split: split used for the default dataset. NOTE: this default is bound
            to the module-level ``split`` when the function is defined.

    Returns:
        dict mapping image name to a list of exactly ``beam_size`` captions,
        ordered by descending beam score and padded with '' when fewer
        sequences were produced.

    Uses the module-level ``encoder``, ``decoder``, ``word_encoder`` and ``device``.
    """
    if img_dataset is None:
        img_dataset = ImgOnlyDataset(split=split, transform=build_transforms(is_train=False))
    # TODO: batched beam search. Until then batch_size MUST stay 1.
    loader = torch.utils.data.DataLoader(img_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

    vocab_size = len(word_encoder.word_list)
    start_wid = word_encoder.word_map['<start>']
    end_wid = word_encoder.word_map['<end>']
    # Special tokens removed from the output captions.
    ignore_wids = {word_encoder.word_map[w] for w in ['<start>', '<end>', '<pad>']}

    def seq_to_caption(seq):
        """Turn a list of word ids into a caption, dropping special tokens."""
        tokens = [word_encoder.word_list[wid] for wid in seq if wid not in ignore_wids]
        return word_encoder.detokenize(tokens)

    predictions = dict()
    with torch.no_grad():
        for bi, (img_name, image) in enumerate(tqdm(loader, desc="PREDICTING AT BEAM SIZE " + str(beam_size))):
            img_name = img_name[0]
            k = beam_size

            image = image.to(device)  # (1, 3, H, W)

            # Encode
            encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
            encoder_dim = encoder_out.size(3)

            # Flatten encoding
            encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
            num_pixels = encoder_out.size(1)

            # Treat the problem as having a batch size of k
            encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

            # Top k previous words / sequences; initially just <start>
            k_prev_words = torch.as_tensor([[start_wid]] * k, dtype=torch.long).to(device)  # (k, 1)
            seqs = k_prev_words  # (k, 1)

            # Top k sequences' scores; initially 0
            top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

            complete_seqs = list()
            complete_seqs_scores = list()

            step = 1
            h, c = decoder.init_hidden_state(encoder_out)

            # s (live beams) can be < k because beams that emit <end> are removed.
            while True:
                embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

                awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim)

                gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
                awe = gate * awe

                h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

                scores = F.log_softmax(decoder.fc(h), dim=1)  # (s, vocab_size)
                scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

                if step == 1:
                    # All beams are identical at the first step; use row 0 only.
                    top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)
                else:
                    # Unroll and find top scores, and their unrolled indices
                    top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)

                # Floor division: '/' on integer tensors yields a float tensor in
                # current PyTorch, which cannot be used to index ``seqs``/``h``.
                prev_word_inds = top_k_words // vocab_size
                next_word_inds = top_k_words % vocab_size

                seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

                next_words = next_word_inds.tolist()
                incomplete_inds = [ind for ind, wid in enumerate(next_words) if wid != end_wid]
                complete_inds = [ind for ind, wid in enumerate(next_words) if wid == end_wid]

                # Set aside complete sequences
                if len(complete_inds) > 0:
                    complete_seqs.extend(seqs[complete_inds].tolist())
                    complete_seqs_scores.extend(top_k_scores[complete_inds].cpu().numpy())

                if len(incomplete_inds) == 0:
                    break

                # Proceed with incomplete sequences
                seqs = seqs[incomplete_inds]
                h = h[prev_word_inds[incomplete_inds]]
                c = c[prev_word_inds[incomplete_inds]]
                encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
                top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
                k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

                # Break if things have been going on too long
                if step > 50:
                    break
                step += 1

            if len(complete_seqs_scores) == 0:
                # No beam reached <end> within the step limit: fall back to the
                # live beams. top_k_scores is (s, 1) here; squeeze only dim 1 so
                # a single live beam still yields a 1-D array.
                complete_seqs_scores = top_k_scores.squeeze(1).cpu().numpy()
                complete_seqs = seqs.cpu().tolist()
            sorted_idxs = np.argsort(np.asarray(complete_seqs_scores) * -1.0)[:beam_size]
            sorted_seqs = [complete_seqs[i] for i in sorted_idxs]

            if bi == 0:  # debug: show the ranked beams for the first image
                print('top k:')
                for seq, idx in zip(sorted_seqs, sorted_idxs):
                    print(seq_to_caption(seq), complete_seqs_scores[idx])

            captions = [seq_to_caption(seq) for seq in sorted_seqs]
            # Pad so every image has exactly beam_size entries.
            captions += [''] * (beam_size - len(captions))
            predictions[img_name] = captions
    return predictions
Exemple #10
0
def _apply_untuned_module_modes(model, freeze_params=False):
    """Put the sub-modules that ``cfg`` says not to tune into eval mode.

    Args:
        model: model exposing ``resnet_encoder`` and ``lang_embed``. They are
            affected when ``cfg.TRAIN.TUNE_RESNET`` / ``cfg.TRAIN.TUNE_LANG_ENCODER``
            are false, respectively.
        freeze_params: when True, also disable gradients for every parameter of
            those sub-modules, so the optimizer (which keeps only parameters
            with ``requires_grad``) does not update them.
    """
    untuned = []
    if not cfg.TRAIN.TUNE_RESNET:
        untuned.append(model.resnet_encoder)
    if not cfg.TRAIN.TUNE_LANG_ENCODER:
        untuned.append(model.lang_embed)
    for module in untuned:
        if freeze_params:
            # Assigning ``module.requires_grad`` only sets a plain attribute on
            # an nn.Module and freezes nothing. The attribute is kept for any
            # code that reads it; the parameters must be flagged one by one.
            module.requires_grad = False
            for param in module.parameters():
                param.requires_grad = False
        module.eval()


def train():
    """Train a TripletMatch model from command-line / config-file settings.

    Command-line options:
        -c/--config_file: optional config file merged into ``cfg``.
        -o/--opts: config overrides, e.g. ``TRAIN.INIT_LR 0.01``.

    Periodically evaluates on ``cfg.EVAL_SPLIT``, saves BEST/LAST checkpoints
    under ``<OUTPUT_PATH>/checkpoints``, decays the learning rate and stops
    early based on how many evaluations passed without improvement.

    Returns:
        ``(p2i_result, i2p_result)`` of the best evaluation, or None if no
        evaluation ran.
    """
    from torch.utils.tensorboard import SummaryWriter
    # load configs
    parser = argparse.ArgumentParser(description="Triplet Matching Training")
    parser.add_argument('-c', '--config_file', default=None, help="path to config file")
    parser.add_argument('-o', '--opts', default=None, nargs=argparse.REMAINDER,
                        help="Modify config options using the command-line. E.g. TRAIN.INIT_LR 0.01",)
    args = parser.parse_args()

    if args.config_file is not None:
        cfg.merge_from_file(args.config_file)
    if args.opts is not None:
        cfg.merge_from_list(args.opts)

    prepare(cfg)
    cfg.freeze()
    print(cfg.dump())

    os.makedirs(cfg.OUTPUT_PATH, exist_ok=True)
    with open(os.path.join(cfg.OUTPUT_PATH, 'train.yml'), 'w') as f:
        f.write(cfg.dump())

    # set random seed
    torch.manual_seed(cfg.RAND_SEED)
    np.random.seed(cfg.RAND_SEED)
    random.seed(cfg.RAND_SEED)

    # make data_loader, model, criterion, optimizer
    dataset = TripletTrainData(split=cfg.TRAIN_SPLIT, neg_img=cfg.LOSS.IMG_SENT_WEIGHTS[0] > 0,
                               neg_lang=cfg.LOSS.IMG_SENT_WEIGHTS[1] > 0, lang_input=cfg.LANG_INPUT)
    data_loader = DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, drop_last=True, pin_memory=True)

    img_dataset = ImgOnlyDataset(split=cfg.EVAL_SPLIT, transform=build_transforms(is_train=False),
                                 texture_dataset=dataset)
    eval_img_dataloader = DataLoader(img_dataset, batch_size=1, shuffle=False)

    phrase_dataset = PhraseOnlyDataset(texture_dataset=dataset)
    eval_phrase_dataloader = DataLoader(phrase_dataset, batch_size=32, shuffle=False)

    word_encoder = WordEncoder()
    model: TripletMatch = TripletMatch(vec_dim=cfg.MODEL.VEC_DIM, neg_margin=cfg.LOSS.MARGIN,
                                       distance=cfg.MODEL.DISTANCE, img_feats=cfg.MODEL.IMG_FEATS,
                                       lang_encoder_method=cfg.MODEL.LANG_ENCODER, word_encoder=word_encoder)

    if cfg.INIT_WORD_EMBED != 'rand' and cfg.MODEL.LANG_ENCODER in ['mean', 'lstm']:
        word_emb = get_word_embed(word_encoder.word_list, cfg.INIT_WORD_EMBED)
        model.lang_embed.embeds.weight.data.copy_(torch.from_numpy(word_emb))
    if len(cfg.LOAD_WEIGHTS) > 0:
        # The model is still on the CPU here, so load the weights there too;
        # this also works for GPU-saved checkpoints on CPU-only machines.
        model.load_state_dict(torch.load(cfg.LOAD_WEIGHTS, map_location='cpu'))

    model.train()
    device = torch.device(cfg.DEVICE)
    model.to(device)

    # Must run before the optimizer is built so frozen parameters are excluded.
    _apply_untuned_module_modes(model, freeze_params=True)

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=cfg.TRAIN.INIT_LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                 betas=(cfg.TRAIN.ADAM.ALPHA, cfg.TRAIN.ADAM.BETA),
                                 eps=cfg.TRAIN.ADAM.EPSILON)

    # make tensorboard writer and dirs
    checkpoint_dir = os.path.join(cfg.OUTPUT_PATH, 'checkpoints')
    vis_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_%s_LAST' % cfg.EVAL_SPLIT)
    best_vis_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_%s_BEST' % cfg.EVAL_SPLIT)
    tb_dir = os.path.join(cfg.OUTPUT_PATH, 'tensorboard')
    tb_writer = SummaryWriter(tb_dir)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(tb_dir, exist_ok=True)

    # training loop
    step = 1
    epoch = 1
    epoch_float = 0.0
    epoch_per_step = cfg.TRAIN.BATCH_SIZE * 1.0 / len(dataset)
    best_eval_metric = 0
    best_metrics = None
    best_eval_count = 0
    early_stop = False
    while epoch <= cfg.TRAIN.MAX_EPOCH and not early_stop:
        for pos_imgs, pos_langs, neg_imgs, neg_langs in data_loader:
            pos_imgs = pos_imgs.to(device)
            if neg_imgs is not None and neg_imgs[0] is not None:
                neg_imgs = neg_imgs.to(device)

            verbose = step <= 5 or step % 50 == 0

            neg_img_loss, neg_lang_loss = model(pos_imgs, pos_langs, neg_imgs, neg_langs, verbose=verbose)

            # Weighted average of the two triplet losses.
            loss = cfg.LOSS.IMG_SENT_WEIGHTS[0] * neg_img_loss + cfg.LOSS.IMG_SENT_WEIGHTS[1] * neg_lang_loss
            loss /= sum(cfg.LOSS.IMG_SENT_WEIGHTS)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lr = optimizer.param_groups[0]['lr']
            tb_writer.add_scalar('train/loss', loss, step)
            tb_writer.add_scalar('train/neg_img_loss', neg_img_loss, step)
            tb_writer.add_scalar('train/neg_lang_loss', neg_lang_loss, step)
            tb_writer.add_scalar('train/lr', lr, step)

            if verbose:
                print('[%s] epoch-%d step-%d: loss %.4f (neg_img: %.4f, neg_lang: %.4f); lr %.1E'
                      % (time.strftime('%m/%d %H:%M:%S'), epoch, step, loss, neg_img_loss, neg_lang_loss, lr))

            # Evaluate on the first step after each EVAL_EVERY_EPOCH boundary.
            if epoch_float % cfg.TRAIN.EVAL_EVERY_EPOCH < epoch_per_step and epoch_float > 0:
                p2i_result, i2p_result = do_eval(model, eval_img_dataloader, eval_phrase_dataloader, device,
                                                 split=cfg.EVAL_SPLIT, visualize_path=vis_path)

                for m, v in p2i_result.items():
                    tb_writer.add_scalar('eval_p2i/%s' % m, v, step)
                for m, v in i2p_result.items():
                    tb_writer.add_scalar('eval_i2p/%s' % m, v, step)

                eval_metric = p2i_result['mean_average_precision'] + i2p_result['mean_average_precision']
                if eval_metric > best_eval_metric:
                    print('EVAL: new best!')
                    best_eval_metric = eval_metric
                    best_metrics = (p2i_result, i2p_result)
                    best_eval_count = 0
                    copy_tree(vis_path, best_vis_path, update=1)
                    with open(os.path.join(checkpoint_dir, 'epoch_step.txt'), 'w') as f:
                        f.write('BEST: epoch {}, step {}\n'.format(epoch, step))
                    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'BEST_checkpoint.pth'))

                else:
                    best_eval_count += 1
                    print('EVAL: since last best: %d' % best_eval_count)
                    if epoch_float % cfg.TRAIN.CHECKPOINT_EVERY_EPOCH < epoch_per_step and epoch_float > 0:
                        with open(os.path.join(checkpoint_dir, 'epoch_step.txt'), 'a') as f:
                            f.write('LAST: epoch {}, step {}\n'.format(epoch, step))
                        torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'LAST_checkpoint.pth'))

                if best_eval_count % cfg.TRAIN.LR_DECAY_EVAL_COUNT == 0 and best_eval_count > 0:
                    print('EVAL: lr decay triggered')
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= cfg.TRAIN.LR_DECAY_GAMMA

                if best_eval_count % cfg.TRAIN.EARLY_STOP_EVAL_COUNT == 0 and best_eval_count > 0:
                    print('EVAL: early stop triggered')
                    early_stop = True
                    break

                # Restore training mode, keeping untuned sub-modules in eval mode.
                model.train()
                _apply_untuned_module_modes(model)

            step += 1
            epoch_float += epoch_per_step

        epoch += 1

    tb_writer.close()
    if best_metrics is not None:
        exp_name = '%s:%s' % (cfg.OUTPUT_PATH, cfg.EVAL_SPLIT)
        log_to_summary(exp_name, best_metrics[0], best_metrics[1])
    return best_metrics
Exemple #11
0
def main_eval():
    """Evaluate a trained TripletMatch model on phrase/image retrieval.

    Command-line options:
        -p/--trained_path: run directory containing ``train.yml`` and
            ``checkpoints/``.
        -m/--model_file: checkpoint file name inside ``checkpoints/``.
        -o/--opts: config overrides, e.g. ``EVAL_SPLIT test``.

    Calls ``do_eval`` on ``cfg.EVAL_SPLIT`` with visualisations written under
    ``<trained_path>/eval_visualize_<split>`` and
    ``add_to_summary_name='<model_path>:<split>'``.
    """
    # load configs
    parser = argparse.ArgumentParser(
        description="Triplet (phrase) retrieval evaluation")
    parser.add_argument('-p',
                        '--trained_path',
                        help="path to trained model (where there is cfg file)",
                        default='output/triplet_match/c34_bert_l2_s_lr0.00001')
    parser.add_argument('-m',
                        '--model_file',
                        help='file name of the cached model ',
                        default='BEST_checkpoint.pth')
    parser.add_argument('-o',
                        '--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help="e.g. EVAL_SPLIT test")
    args = parser.parse_args()

    cfg.merge_from_file(os.path.join(args.trained_path, 'train.yml'))
    if args.opts is not None:
        cfg.merge_from_list(args.opts)

    # set random seed
    torch.manual_seed(cfg.RAND_SEED)
    np.random.seed(cfg.RAND_SEED)
    random.seed(cfg.RAND_SEED)

    dataset = TextureDescriptionData(phid_format=None)
    img_dataset = ImgOnlyDataset(split=cfg.EVAL_SPLIT,
                                 transform=build_transforms(is_train=False),
                                 texture_dataset=dataset)
    img_dataloader = DataLoader(img_dataset, batch_size=1, shuffle=False)

    phrase_dataset = PhraseOnlyDataset(texture_dataset=dataset)
    phrase_dataloader = DataLoader(phrase_dataset,
                                   batch_size=32,
                                   shuffle=False)

    model: TripletMatch = TripletMatch(
        vec_dim=cfg.MODEL.VEC_DIM,
        neg_margin=cfg.LOSS.MARGIN,
        distance=cfg.MODEL.DISTANCE,
        img_feats=cfg.MODEL.IMG_FEATS,
        lang_encoder_method=cfg.MODEL.LANG_ENCODER)

    model_path = os.path.join(args.trained_path, 'checkpoints',
                              args.model_file)
    # Load onto the CPU first (the model is still there), then move the whole
    # model to cfg.DEVICE. Without map_location, a checkpoint saved from a GPU
    # cannot be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    device = torch.device(cfg.DEVICE)
    model.to(device)

    do_eval(model,
            img_dataloader,
            phrase_dataloader,
            device,
            split=cfg.EVAL_SPLIT,
            visualize_path=os.path.join(args.trained_path,
                                        'eval_visualize_%s' % cfg.EVAL_SPLIT),
            add_to_summary_name=model_path + ':' + cfg.EVAL_SPLIT)