예제 #1
0
def knowledge_attribute_train(context_text_encoder: TextEncoder,
                              context_image_encoder: ImageEncoder,
                              context_encoder: ContextEncoder,
                              train_dataset: Dataset,
                              valid_dataset: Dataset,
                              test_dataset: Dataset,
                              model_file: str,
                              attribute_data: AttributeData,
                              vocab: Dict[str, int],
                              embed_init=None):
    """Knowledge attribute train.

    Trains a ``ToHidden`` + attribute ``KVMemory`` + ``TextDecoder`` stack
    jointly with the given context encoders; all of their parameters share a
    single Adam optimizer. Every ``KnowledgeAttributeTrainConfig.valid_freq``
    batches the model is validated and, on an improved validation loss, a
    checkpoint is written to ``model_file``. Once the validation loss has
    failed to improve more than ``KnowledgeAttributeTrainConfig.patience``
    times in a row, the test set is evaluated and training stops.

    If ``model_file`` already exists, the ``to_hidden``,
    ``attribute_kv_memory`` and ``text_decoder`` weights, the optimizer
    state, the epoch index and the best validation loss are restored from it.
    The context encoders are saved in the checkpoint but not reloaded here
    (presumably the caller restores them — TODO confirm).

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        train_dataset (Dataset): Train dataset.
        valid_dataset (Dataset): Valid dataset.
        test_dataset (Dataset): Test dataset.
        model_file (str): Saved model file.
        attribute_data (AttributeData): Attribute data (its key/value
            vocabulary sizes configure the KV memory).
        vocab (Dict[str, int]): Vocabulary.
        embed_init: Initial embedding (vocab_size, embed_size).

    """

    # Data loader.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=KnowledgeAttributeTrainConfig.batch_size,
        shuffle=True,
        num_workers=KnowledgeAttributeTrainConfig.num_data_loader_workers)

    # Model.
    vocab_size = len(vocab)

    attribute_kv_memory_config = AttributeKVMemoryConfig(
        len(attribute_data.key_vocab), len(attribute_data.value_vocab))
    text_decoder_config = KnowledgeTextDecoderConfig(vocab_size,
                                                     MemoryConfig.memory_size,
                                                     MemoryConfig.output_size,
                                                     embed_init)

    to_hidden = ToHidden(text_decoder_config)
    to_hidden = to_hidden.to(GlobalConfig.device)

    attribute_kv_memory = KVMemory(attribute_kv_memory_config)
    attribute_kv_memory = attribute_kv_memory.to(GlobalConfig.device)

    text_decoder = TextDecoder(text_decoder_config)
    text_decoder = text_decoder.to(GlobalConfig.device)

    # Model parameters.
    params = list(
        chain.from_iterable([
            list(model.parameters()) for model in [
                context_text_encoder, context_image_encoder, context_encoder,
                to_hidden, attribute_kv_memory, text_decoder
            ]
        ]))

    optimizer = Adam(params, lr=KnowledgeAttributeTrainConfig.learning_rate)
    epoch_id = 0
    min_valid_loss = None

    # Load saved state.
    if isfile(model_file):
        state = torch.load(model_file)
        to_hidden.load_state_dict(state['to_hidden'])
        attribute_kv_memory.load_state_dict(state['attribute_kv_memory'])
        text_decoder.load_state_dict(state['text_decoder'])
        optimizer.load_state_dict(state['optimizer'])
        epoch_id = state['epoch_id']
        min_valid_loss = state['min_valid_loss']

    # Loss.
    sum_loss = 0.0
    bad_loss_cnt = 0

    # Switch to train mode.
    context_text_encoder.train()
    context_image_encoder.train()
    context_encoder.train()
    to_hidden.train()
    attribute_kv_memory.train()
    text_decoder.train()

    finished = False

    for epoch_id in range(epoch_id,
                          KnowledgeAttributeTrainConfig.num_iterations):
        for batch_id, train_data in enumerate(train_data_loader):
            # Set gradients to 0.
            optimizer.zero_grad()

            train_data, products = train_data
            keys, values, pair_length = products
            keys = keys.to(GlobalConfig.device)
            values = values.to(GlobalConfig.device)
            pair_length = pair_length.to(GlobalConfig.device)

            texts, text_lengths, images, utter_types = train_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(context_text_encoder,
                                              context_image_encoder,
                                              context_encoder, texts,
                                              text_lengths, images)
            # (batch_size, context_vector_size)

            # The KV memory is bound to this batch's product attributes.
            encode_knowledge_func = partial(attribute_kv_memory, keys, values,
                                            pair_length)

            # The last utterance (texts[-1]) is the decoding target.
            loss, _ = text_loss(to_hidden, text_decoder,
                                text_decoder_config.text_length,
                                context, texts[-1], text_lengths[-1],
                                hiddens, encode_knowledge_func)
            # Accumulate a plain float: adding the loss tensor itself would
            # chain autograd history across batches until the next print and
            # would print as a tensor with a grad_fn.
            sum_loss += loss.item() / text_decoder_config.text_length

            loss.backward()
            optimizer.step()

            # Print loss every `TrainConfig.print_freq` batches.
            if (batch_id + 1) % KnowledgeAttributeTrainConfig.print_freq == 0:
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                sum_loss /= KnowledgeAttributeTrainConfig.print_freq
                print('epoch: {} \tbatch: {} \tloss: {} \ttime: {}'.format(
                    epoch_id + 1, batch_id + 1, sum_loss, cur_time))
                sum_loss = 0.0

            # Valid every `TrainConfig.valid_freq` batches.
            if (batch_id + 1) % KnowledgeAttributeTrainConfig.valid_freq == 0:
                valid_loss = knowledge_attribute_valid(
                    context_text_encoder, context_image_encoder,
                    context_encoder, to_hidden, attribute_kv_memory,
                    text_decoder, valid_dataset,
                    text_decoder_config.text_length)
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('valid_loss: {} \ttime: {}'.format(valid_loss, cur_time))

                # Save current best model.
                if min_valid_loss is None or valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    bad_loss_cnt = 0
                    save_dict = {
                        'task': KNOWLEDGE_ATTRIBUTE_SUBTASK,
                        'epoch_id': epoch_id,
                        'min_valid_loss': min_valid_loss,
                        'optimizer': optimizer.state_dict(),
                        'context_text_encoder':
                        context_text_encoder.state_dict(),
                        'context_image_encoder':
                        context_image_encoder.state_dict(),
                        'context_encoder': context_encoder.state_dict(),
                        'to_hidden': to_hidden.state_dict(),
                        'attribute_kv_memory':
                        attribute_kv_memory.state_dict(),
                        'text_decoder': text_decoder.state_dict()
                    }
                    torch.save(save_dict, model_file)
                    print('Best model saved.')
                else:
                    bad_loss_cnt += 1
                    # Early stopping: evaluate on the test set once and stop.
                    if bad_loss_cnt > KnowledgeAttributeTrainConfig.patience:
                        knowledge_attribute_test(
                            context_text_encoder, context_image_encoder,
                            context_encoder, to_hidden, attribute_kv_memory,
                            text_decoder, test_dataset,
                            text_decoder_config.text_length, vocab)
                        finished = True
                        break
        if finished:
            break
예제 #2
0
def knowledge_celebrity_valid(
        context_text_encoder: TextEncoder, context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder, to_hidden: ToHidden,
        celebrity_memory: Memory, text_decoder: TextDecoder,
        valid_dataset: Dataset, celebrity_scores, text_length: int):
    """Knowledge celebrity valid.

    Evaluates the celebrity-knowledge decoding stack on at most
    ``KnowledgeCelebrityValidConfig.num_batches`` shuffled batches of
    ``valid_dataset``, without gradients. Every module is put in eval mode
    for the evaluation and switched back to train mode before returning.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        celebrity_memory (Memory): Celebrity memory, queried with
            ``celebrity_scores`` for every batch.
        text_decoder (TextDecoder): Text decoder.
        valid_dataset (Dataset): Valid dataset.
        celebrity_scores: Celebrity scores.
        text_length (int): Text length used to normalize the loss.

    Returns:
        The average over evaluated batches of ``text_loss / text_length``.

    """
    modules = (context_text_encoder, context_image_encoder, context_encoder,
               to_hidden, celebrity_memory, text_decoder)

    loader = DataLoader(
        valid_dataset,
        batch_size=KnowledgeCelebrityValidConfig.batch_size,
        shuffle=True,
        num_workers=KnowledgeCelebrityValidConfig.num_data_loader_workers)

    total_loss = 0
    batches_seen = 0

    for module in modules:
        module.eval()

    # The memory always reads the same score table, so bind it once.
    knowledge_fn = partial(celebrity_memory, celebrity_scores)

    with torch.no_grad():
        for index, batch in enumerate(loader):
            # Cap the evaluation at `num_batches` batches.
            if index >= KnowledgeCelebrityValidConfig.num_batches:
                break
            batches_seen += 1

            # Batch-first on arrival:
            #   texts        (batch, ctx + 1, max_len)
            #   text_lengths (batch, ctx + 1)
            #   images       (batch, ctx + 1, pos_images, 3, size, size)
            texts, text_lengths, images, _ = batch

            # Move to the device and make the dialog turn the leading axis.
            texts = texts.to(GlobalConfig.device).transpose_(0, 1)
            text_lengths = text_lengths.to(GlobalConfig.device).transpose_(0, 1)
            images = (images.to(GlobalConfig.device)
                      .transpose_(0, 1)
                      .transpose_(1, 2))
            # images: (ctx + 1, pos_images, batch, 3, size, size)

            context, hiddens = encode_context(context_text_encoder,
                                              context_image_encoder,
                                              context_encoder, texts,
                                              text_lengths, images)

            # The last turn is the response being decoded.
            batch_loss, _ = text_loss(to_hidden, text_decoder, text_length,
                                      context, texts[-1], text_lengths[-1],
                                      hiddens, knowledge_fn)
            total_loss += batch_loss / text_length

    for module in modules:
        module.train()

    return total_loss / batches_seen
예제 #3
0
def intention_valid(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder, intention: Intention,
                    valid_dataset: Dataset):
    """Intention valid.

    Evaluates the intention classifier on at most
    ``IntentionValidConfig.num_batches`` shuffled batches of
    ``valid_dataset`` without gradients. Modules are put in eval mode for
    the evaluation and switched back to train mode before returning.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        intention (Intention): Intention classifier whose output is fed to
            ``nll_loss``.
        valid_dataset (Dataset): Valid dataset.

    Returns:
        ``(mean_loss, mean_accuracy)``: both averaged per evaluated batch.

    """
    modules = (context_text_encoder, context_image_encoder, context_encoder,
               intention)

    loader = DataLoader(
        valid_dataset,
        batch_size=IntentionValidConfig.batch_size,
        shuffle=True,
        num_workers=IntentionValidConfig.num_data_loader_workers)

    loss_total = 0
    accuracy_total = 0
    batch_count = 0

    for module in modules:
        module.eval()

    with torch.no_grad():
        for index, batch in enumerate(loader):
            # Cap the evaluation at `num_batches` batches.
            if index >= IntentionValidConfig.num_batches:
                break
            batch_count += 1

            # Batch-first on arrival:
            #   texts        (batch, ctx + 1, max_len)
            #   text_lengths (batch, ctx + 1)
            #   images       (batch, ctx + 1, pos_images, 3, size, size)
            #   utter_types  (batch,)
            texts, text_lengths, images, utter_types = batch

            # Move to the device and make the dialog turn the leading axis.
            texts = texts.to(GlobalConfig.device).transpose_(0, 1)
            text_lengths = text_lengths.to(GlobalConfig.device).transpose_(0, 1)
            images = (images.to(GlobalConfig.device)
                      .transpose_(0, 1)
                      .transpose_(1, 2))
            utter_types = utter_types.to(GlobalConfig.device)

            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder, context_encoder,
                                        texts, text_lengths, images)

            intent_prob = intention(context)
            # (batch, utterance_type_size)

            loss_total += nll_loss(intent_prob, utter_types)

            hits = intent_prob.argmax(dim=1).eq(utter_types)
            accuracy_total += hits.sum().item() / hits.size(0)

    for module in modules:
        module.train()

    return loss_total / batch_count, accuracy_total / batch_count
예제 #4
0
def recommend_train(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder,
                    train_dataset: Dataset,
                    valid_dataset: Dataset,
                    test_dataset: Dataset,
                    model_file: str,
                    vocab_size: int,
                    embed_init=None):
    """Recommend train.

    Trains a ``Similarity`` model jointly with the given context encoders
    using ``recommend_loss`` over positive/negative products; all parameters
    share one Adam optimizer. Every ``RecommendTrainConfig.valid_freq``
    batches the model is validated and, on an improved validation loss, a
    checkpoint is written to ``model_file``. Once the validation loss has
    failed to improve more than ``RecommendTrainConfig.patience`` times in a
    row, the test set is evaluated and training stops.

    If ``model_file`` already exists, the similarity weights, optimizer
    state, epoch index and best validation loss are restored from it (the
    context encoders are saved but not reloaded here — TODO confirm the
    caller restores them).

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        train_dataset (Dataset): Train dataset.
        valid_dataset (Dataset): Valid dataset.
        test_dataset (Dataset): Test dataset.
        model_file (str): Saved model file.
        vocab_size (int): Vocabulary size.
        embed_init: Initial embedding (vocab_size, embed_size).

    """

    # Data loader.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=RecommendTrainConfig.batch_size,
        shuffle=True,
        num_workers=RecommendTrainConfig.num_data_loader_workers)

    # Model.
    similarity_config = SimilarityConfig(vocab_size, embed_init)
    similarity = Similarity(similarity_config).to(GlobalConfig.device)

    # Model parameters.
    params = list(
        chain.from_iterable([
            list(model.parameters()) for model in [
                context_text_encoder, context_image_encoder, context_encoder,
                similarity
            ]
        ]))

    optimizer = Adam(params, lr=RecommendTrainConfig.learning_rate)
    epoch_id = 0
    min_valid_loss = None

    # Load saved state.
    if isfile(model_file):
        state = torch.load(model_file)
        similarity.load_state_dict(state['similarity'])
        optimizer.load_state_dict(state['optimizer'])
        epoch_id = state['epoch_id']
        min_valid_loss = state['min_valid_loss']

    # Loss.
    sum_loss = 0.0
    bad_loss_cnt = 0

    # Switch to train mode.
    context_text_encoder.train()
    context_image_encoder.train()
    context_encoder.train()
    similarity.train()

    finished = False

    for epoch_id in range(epoch_id, RecommendTrainConfig.num_iterations):
        for batch_id, train_data in enumerate(train_data_loader):
            # Sets gradients to 0.
            optimizer.zero_grad()

            context_dialog, pos_products, neg_products = train_data

            texts, text_lengths, images, utter_types = context_dialog
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            batch_size = texts.size(0)

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder, context_encoder,
                                        texts, text_lengths, images)
            # (batch_size, context_vector_size)

            loss = recommend_loss(similarity, batch_size, context,
                                  pos_products, neg_products)
            # Accumulate a plain float: adding the loss tensor itself would
            # chain autograd history across batches until the next print and
            # would print as a tensor with a grad_fn.
            sum_loss += loss.item()

            loss.backward()
            optimizer.step()

            # Print loss every `TrainConfig.print_freq` batches.
            if (batch_id + 1) % RecommendTrainConfig.print_freq == 0:
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                sum_loss /= RecommendTrainConfig.print_freq
                print('epoch: {} \tbatch: {} \tloss: {} \ttime: {}'.format(
                    epoch_id + 1, batch_id + 1, sum_loss, cur_time))
                sum_loss = 0.0

            # Valid every `TrainConfig.valid_freq` batches.
            if (batch_id + 1) % RecommendTrainConfig.valid_freq == 0:
                valid_loss = recommend_valid(context_text_encoder,
                                             context_image_encoder,
                                             context_encoder, similarity,
                                             valid_dataset)
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('valid_loss: {} \ttime: {}'.format(valid_loss, cur_time))

                # Save current best model.
                if min_valid_loss is None or valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    bad_loss_cnt = 0
                    save_dict = {
                        'task': RECOMMEND_TASK,
                        'epoch_id': epoch_id,
                        'min_valid_loss': min_valid_loss,
                        'optimizer': optimizer.state_dict(),
                        'context_text_encoder':
                        context_text_encoder.state_dict(),
                        'context_image_encoder':
                        context_image_encoder.state_dict(),
                        'context_encoder': context_encoder.state_dict(),
                        'similarity': similarity.state_dict()
                    }
                    torch.save(save_dict, model_file)
                    print('Best model saved.')
                else:
                    bad_loss_cnt += 1

                # Early stopping: evaluate on the test set once and stop.
                if bad_loss_cnt > RecommendTrainConfig.patience:
                    recommend_test(context_text_encoder, context_image_encoder,
                                   context_encoder, similarity, test_dataset)
                    finished = True
                    break
        if finished:
            break
    # NOTE(review): the lines below do not belong to recommend_train. They
    # look like the truncated body of a separate training example whose
    # header was lost when these snippets were collected. They reference
    # `start_epoch`, `local_rank`, `loader`, `score_model` and `image_encoder`,
    # none of which recommend_train defines, so reaching this loop would raise
    # NameError. The fragment also ends mid-loop. TODO: confirm provenance and
    # move it into its own function, or remove it.
    # image_encoder = nn.parallel.DistributedDataParallel(image_encoder, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
    # contrastive_loss = ContrastiveLoss(0.9, max_violation=True, reduction='mean')
    # contrastive_loss = ExponentialLoss()
    for epoch in range(start_epoch, 30):
        # tbar = tqdm(loader)
        # Only the process with local_rank 0 wraps the loader in a tqdm
        # progress bar; other ranks iterate the plain loader.
        if local_rank == 0:
            tbar = tqdm(loader)
        else:
            tbar = loader
        losses_manual_mining = 0.
        losses_hard_mining = 0.
        # losses_gen = 0.
        # losses3 = 0.
        # losses_classify = 0.
        score_model.train()
        image_encoder.train()
        # Each batch carries a positive (query, features, boxes, obj_len)
        # group and a matching *_neg group.
        for i, (query, query_len, features, boxes, obj_len, query_neg,
                query_neg_len, features_neg, boxes_neg,
                obj_neg_len) in enumerate(tbar):
            optimizer.zero_grad()
            batch_size = query.size(0)
            # All-ones target; presumably consumed by a ranking/contrastive
            # loss further down (not visible here) — TODO confirm.
            target = torch.ones(batch_size).cuda()
            query = query.cuda()
            query_len = query_len.cuda()
            obj_len = obj_len.cuda()
            obj_neg_len = obj_neg_len.cuda()

            # NOTE(review): `features`, `features_neg`, `query_neg` and
            # `query_neg_len` are not moved to CUDA in the visible lines.
            boxes = boxes.cuda()
            boxes_neg = boxes_neg.cuda()
            # category = category.cuda()
            # category_neg = category_neg.cuda()
예제 #6
0
def recommend_valid(
        context_text_encoder: TextEncoder,
        context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder,
        similarity: Similarity,
        valid_dataset: Dataset):
    """Recommend valid.

    Evaluates the recommender on at most ``RecommendValidConfig.num_batches``
    shuffled batches of ``valid_dataset`` without gradients. It accumulates
    ``recommend_loss`` and the per-rank counts returned by
    ``recommend_eval``, then prints ``total recall@k`` for
    ``k = 1 .. DatasetConfig.neg_images_max_num``.

    The context encoders are put in eval mode during evaluation, but
    ``similarity`` is deliberately left as-is (the original author suspected
    a bug in the ResNet implementation). All four modules, including
    ``similarity``, are switched to train mode on exit.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        similarity (Similarity): Similarity model passed to
            ``recommend_loss`` and ``recommend_eval``.
        valid_dataset (Dataset): Valid dataset.

    Returns:
        The mean ``recommend_loss`` per evaluated batch.

    """
    encoders = (context_text_encoder, context_image_encoder, context_encoder)

    loader = DataLoader(
        valid_dataset,
        batch_size=RecommendValidConfig.batch_size,
        shuffle=True,
        num_workers=RecommendValidConfig.num_data_loader_workers
    )

    loss_total = 0
    batch_count = 0

    for encoder in encoders:
        encoder.eval()
    # similarity.eval() is intentionally skipped: there might be a bug in
    # the implementation of resnet.

    # rank_counts[r] counts samples whose positive product landed at rank r.
    rank_counts = torch.zeros(DatasetConfig.neg_images_max_num + 1,
                              dtype=torch.long).to(GlobalConfig.device)
    sample_count = 0

    with torch.no_grad():
        for index, batch in enumerate(loader):
            # Cap the evaluation at `num_batches` batches.
            if index >= RecommendValidConfig.num_batches:
                break
            batch_count += 1

            dialog, pos_products, neg_products = batch
            # Batch-first on arrival:
            #   texts        (batch, ctx + 1, max_len)
            #   text_lengths (batch, ctx + 1)
            #   images       (batch, ctx + 1, pos_images, 3, size, size)
            texts, text_lengths, images, _ = dialog
            n_samples = texts.size(0)

            # Move to the device and make the dialog turn the leading axis.
            texts = texts.to(GlobalConfig.device).transpose_(0, 1)
            text_lengths = text_lengths.to(GlobalConfig.device).transpose_(0, 1)
            images = (images.to(GlobalConfig.device)
                      .transpose_(0, 1)
                      .transpose_(1, 2))

            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder,
                                        texts, text_lengths, images)

            loss_total += recommend_loss(similarity, n_samples, context,
                                         pos_products, neg_products)
            rank_counts += recommend_eval(similarity, n_samples, context,
                                          pos_products, neg_products)
            sample_count += n_samples

    # recall@k = fraction of samples ranked within the top k.
    cumulative_hits = torch.cumsum(rank_counts, dim=0)
    for k in range(DatasetConfig.neg_images_max_num):
        print('total recall@{} = {}'.format(
            k + 1, cumulative_hits[k].item() / sample_count))

    for module in encoders + (similarity,):
        module.train()

    return loss_total / batch_count
예제 #7
0
def main(args):
    """Train the VQA model and checkpoint it after every epoch.

    Builds an ``ImageEncoder``, a ``BertEncoder`` question encoder and a
    ``VQA_Model`` decoder over ``no_ans`` answer classes, optionally resumes
    them from epoch ``args.pretrained_epoch``, and trains them with
    cross-entropy. After each epoch the three state dicts are written to
    ``args.model_path`` and a one-line summary is written to
    ``result/result_<epoch>.txt``. Training stops early once the epoch-average
    loss has exceeded the best seen so far for 5 consecutive epochs.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model_path``,
            ``image_size``, ``vocab_path``, ``train_image_dir``,
            ``train_vqa_path``, ``ix_to_ans_file``, ``train_description_file``,
            ``batch_size``, ``num_workers``, ``img_feature_size``,
            ``pretrained_epoch``, ``learning_rate``, ``weight_decay``,
            ``num_epochs`` and ``log_step``.
    """
    # Create model directory
    os.makedirs(args.model_path, exist_ok=True)

    # Image preprocessing
    train_transform = transforms.Compose([
        transforms.RandomCrop(args.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # val_transform = transforms.Compose([
    #     transforms.Resize(args.image_size, interpolation=Image.LANCZOS),
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.485, 0.456, 0.406),
    #                          (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper.
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    # Build data loader
    train_data_loader = get_loader(args.train_image_dir, args.train_vqa_path,
                                   args.ix_to_ans_file,
                                   args.train_description_file, vocab,
                                   train_transform, args.batch_size,
                                   shuffle=True, num_workers=args.num_workers)
    # val_data_loader = get_loader(args.val_image_dir, args.val_vqa_path, args.ix_to_ans_file, vocab, val_transform, args.batch_size, shuffle=False, num_workers=args.num_workers)

    image_encoder = ImageEncoder(args.img_feature_size)
    question_emb_size = 1024
    no_ans = 1000
    question_encoder = BertEncoder(question_emb_size)
    # NOTE: a question-description encoder variant is disabled in this
    # version; the decoder only takes image and question embeddings.
    vqa_decoder = VQA_Model(args.img_feature_size, question_emb_size, no_ans)

    def checkpoint_path(name, epoch_number):
        # One naming scheme for both saving and resuming.
        return os.path.join(args.model_path,
                            '%s-%d.pkl' % (name, epoch_number))

    pretrained_epoch = 0
    if args.pretrained_epoch > 0:
        pretrained_epoch = args.pretrained_epoch
        # Resume from where checkpoints are saved (args.model_path) instead of
        # a hard-coded './models/'. map_location='cpu' lets GPU-saved weights
        # load on any machine; the modules are moved to CUDA below.
        for name, module in (('image_encoder', image_encoder),
                             ('question_encoder', question_encoder),
                             ('vqa_decoder', vqa_decoder)):
            module.load_state_dict(
                torch.load(checkpoint_path(name, pretrained_epoch),
                           map_location='cpu'))

    if torch.cuda.is_available():
        image_encoder.cuda()
        question_encoder.cuda()
        vqa_decoder.cuda()
        print("Cuda is enabled...")

    criterion = nn.CrossEntropyLoss()
    params = (list(image_encoder.parameters())
              + list(question_encoder.parameters())
              + list(vqa_decoder.parameters()))
    optimizer = torch.optim.Adam(params, lr=args.learning_rate,
                                 weight_decay=args.weight_decay)
    total_train_step = len(train_data_loader)

    min_avg_loss = float("inf")
    overfit_warn = 0

    # Per-epoch summaries go here; make sure the directory exists.
    os.makedirs("result", exist_ok=True)

    # Skip epochs already covered by the resumed checkpoint.
    for epoch in range(pretrained_epoch, args.num_epochs):
        image_encoder.train()
        question_encoder.train()
        vqa_decoder.train()
        loss_sum = 0.0
        num_correct = 0
        num_samples = 0
        num_batches = 0
        for bi, (question_arr, image_vqa, target_answer, answer_str) in enumerate(train_data_loader):
            image_encoder.zero_grad()
            question_encoder.zero_grad()
            vqa_decoder.zero_grad()

            images = to_var(torch.stack(image_vqa))
            question_arr = to_var(torch.stack(question_arr))
            target_answer = to_var(torch.tensor(target_answer))

            image_emb = image_encoder(images)
            question_emb = question_encoder(question_arr)
            output = vqa_decoder(image_emb, question_emb)

            loss = criterion(output, target_answer)

            _, prediction = torch.max(output, 1)
            # Use the real batch size: the last batch may be smaller than
            # args.batch_size.
            batch_size = target_answer.size(0)
            no_correct_prediction = prediction.eq(target_answer).sum().item()
            accuracy = no_correct_prediction * 100 / batch_size

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            # CrossEntropyLoss already averages over the batch, so the epoch
            # loss is averaged per batch; accuracy is averaged per sample.
            loss_sum += loss_value
            num_correct += no_correct_prediction
            num_samples += batch_size
            num_batches += 1

            # Print log info
            if bi % args.log_step == 0:
                print('Epoch [%d/%d], Train Step [%d/%d], Loss: %.4f, Acc: %.4f'
                      % (epoch + 1, args.num_epochs, bi, total_train_step,
                         loss_value, accuracy))

        avg_loss = loss_sum / max(num_batches, 1)
        avg_acc = num_correct / max(num_samples, 1)
        print('Epoch [%d/%d], Average Train Loss: %.4f, Average Train acc: %.4f' % (epoch + 1, args.num_epochs, avg_loss, avg_acc))

        # Save the models
        torch.save(image_encoder.state_dict(),
                   checkpoint_path('image_encoder', epoch + 1))
        torch.save(question_encoder.state_dict(),
                   checkpoint_path('question_encoder', epoch + 1))
        torch.save(vqa_decoder.state_dict(),
                   checkpoint_path('vqa_decoder', epoch + 1))

        # Count consecutive epochs whose loss is worse than the best so far.
        overfit_warn = overfit_warn + 1 if (min_avg_loss < avg_loss) else 0
        min_avg_loss = min(min_avg_loss, avg_loss)
        lossFileName = "result/result_" + str(epoch) + ".txt"
        with open(lossFileName, 'w') as test_fd:
            test_fd.write('Epoch: ' + str(epoch) + ' avg_loss: ' + str(avg_loss) + " avg_acc: " + str(avg_acc) + "\n")

        if overfit_warn >= 5:
            print("terminated as overfitted")
            break
예제 #8
0
def intention_train(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder, train_dataset: Dataset,
                    valid_dataset: Dataset, test_dataset: Dataset,
                    model_file: str):
    """Intention train.

    Jointly trains the three context encoders and an ``Intention`` classifier
    with an NLL loss against the utterance types of the training dataset.

    Every ``IntentionTrainConfig.valid_freq`` batches the model is evaluated
    with ``intention_valid``. Whenever the validation loss improves, the
    states of all four modules and the optimizer are written to
    ``model_file``. Once the validation loss fails to improve more than
    ``IntentionTrainConfig.patience`` times in a row, ``intention_test`` is
    run and training stops.

    If ``model_file`` already exists, the intention classifier, optimizer
    state, epoch counter and best validation loss are restored from it and
    training resumes from the saved epoch.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        train_dataset (Dataset): Train dataset.
        valid_dataset (Dataset): Valid dataset.
        test_dataset (Dataset): Test dataset.
        model_file (str): Saved model file.
    """

    # Data loader.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=IntentionTrainConfig.batch_size,
        shuffle=True,
        num_workers=IntentionTrainConfig.num_data_loader_workers)

    # Model.
    intention_config = IntentionConfig()
    intention = Intention(intention_config).to(GlobalConfig.device)

    # Model parameters: the encoders and the classifier are optimized jointly.
    params = list(
        chain.from_iterable([
            list(model.parameters()) for model in [
                context_text_encoder, context_image_encoder, context_encoder,
                intention
            ]
        ]))

    optimizer = Adam(params, lr=IntentionTrainConfig.learning_rate)
    epoch_id = 0
    min_valid_loss = None

    # Load saved state.
    # map_location lets a checkpoint written on one device (e.g. a GPU) be
    # resumed on another (e.g. CPU only).
    # NOTE(review): only the intention head and optimizer are restored here,
    # although the encoder states are saved below as well; presumably the
    # caller restores the encoders — confirm.
    if isfile(model_file):
        state = torch.load(model_file, map_location=GlobalConfig.device)
        intention.load_state_dict(state['intention'])
        optimizer.load_state_dict(state['optimizer'])
        epoch_id = state['epoch_id']
        min_valid_loss = state['min_valid_loss']

    # Number of consecutive validations without improvement.
    bad_loss_cnt = 0

    # Switch to train mode.
    context_text_encoder.train()
    context_image_encoder.train()
    context_encoder.train()
    intention.train()

    finished = False

    for epoch_id in range(epoch_id, IntentionTrainConfig.num_iterations):
        # Running loss of the current print window. Reset every epoch so that
        # a partial window from the previous epoch does not leak into the
        # first average printed in this one.
        sum_loss = 0.0

        for batch_id, train_data in enumerate(train_data_loader):
            # Sets gradients to 0.
            optimizer.zero_grad()

            texts, text_lengths, images, utter_types = train_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder, context_encoder,
                                        texts, text_lengths, images)
            # (batch_size, context_vector_size)

            intent_prob = intention(context)
            # (batch_size, utterance_type_size)

            loss = nll_loss(intent_prob, utter_types)
            # .item() yields a plain float. Accumulating the tensor itself
            # would keep every batch's autograd graph alive until the next
            # print, and would print a tensor repr instead of a number.
            sum_loss += loss.item()

            loss.backward()
            optimizer.step()

            # Print the average loss every
            # `IntentionTrainConfig.print_freq` batches.
            if (batch_id + 1) % IntentionTrainConfig.print_freq == 0:
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                sum_loss /= IntentionTrainConfig.print_freq
                print('epoch: {} \tbatch: {} \tloss: {} \ttime: {}'.format(
                    epoch_id + 1, batch_id + 1, sum_loss, cur_time))
                sum_loss = 0.0

            # Validate every `IntentionTrainConfig.valid_freq` batches.
            # NOTE(review): confirm that intention_valid switches the modules
            # back to train mode before returning.
            if (batch_id + 1) % IntentionTrainConfig.valid_freq == 0:
                valid_loss, accuracy = intention_valid(context_text_encoder,
                                                       context_image_encoder,
                                                       context_encoder,
                                                       intention,
                                                       valid_dataset)
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('valid_loss: {} \taccuracy: {} \ttime: {}'.format(
                    valid_loss, accuracy, cur_time))

                # Save current best model.
                if min_valid_loss is None or valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    bad_loss_cnt = 0
                    save_dict = {
                        'task': INTENTION_TASK,
                        'epoch_id': epoch_id,
                        'min_valid_loss': min_valid_loss,
                        'optimizer': optimizer.state_dict(),
                        'context_text_encoder':
                        context_text_encoder.state_dict(),
                        'context_image_encoder':
                        context_image_encoder.state_dict(),
                        'context_encoder': context_encoder.state_dict(),
                        'intention': intention.state_dict()
                    }
                    torch.save(save_dict, model_file)
                    print('Best model saved.')
                else:
                    # Early stopping: after `patience` non-improving
                    # validations, evaluate on the test set and stop.
                    bad_loss_cnt += 1
                    if bad_loss_cnt > IntentionTrainConfig.patience:
                        intention_test(context_text_encoder,
                                       context_image_encoder, context_encoder,
                                       intention, test_dataset)
                        finished = True
                        break
        if finished:
            break