def main(beam_search_type, args):
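    """Evaluate a trained captioning model on a held-out split.

    Loads the word map, builds a test loader over precomputed image features
    for args.test_split, restores weights from args.weight when given, and
    runs validate() with the requested beam search strategy.
    """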

    print('The arguments are')
    print(args)
    word_map_path = f'./dataset/wordmap_{args.dataset}.json'
    with open(word_map_path, 'r') as fp:
        word_map = json.load(fp)

    # Features are precomputed (ImagecapDatasetFromFeature), so only tensor
    # conversion is needed; no resizing or ImageNet normalization applies here.
    val_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    print('==========Loading Data==========')
    val_data = ImagecapDatasetFromFeature(
        args.dataset,
        args.test_split,
        val_transform,
    )
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print(f'{len(val_loader)} test batches')
    print('==========Data Loaded==========')
    print('==========Setting Model==========')
    if args.model_type == 'adaptive':
        model = adaptiveattention.AdaptiveAttentionCaptioningModel(
            args.embed_dim, args.hidden_dim, len(word_map), args.encoder)

    elif args.model_type == 'gridtd':
        model = gridTDmodel.GridTDModelBU(args.embed_dim, args.hidden_dim,
                                          len(word_map), args.encoder)
    elif args.model_type == 'aoa':
        model = aoamodel.AOAModelBU(args.embed_dim, args.hidden_dim,
                                    args.num_head, len(word_map), args.encoder)
    else:
        raise NotImplementedError(
            f'model_type {args.model_type} is not available yet')
    model.cuda()

    if args.weight:
        print(f'==========Resuming weights from {args.weight}==========')
        checkpoint = torch.load(args.weight)
        start_epoch = checkpoint['epoch']
        # epochs_since_improvement = checkpoint['epochs_since_improvement']
        # best_cider = checkpoint['cider']
        model.load_state_dict(checkpoint['state_dict'])
    else:
        print('==========Initializing model from random==========')
        start_epoch = 0
    print('==========Start Testing==========')
    validate(val_loader,
             model,
             word_map,
             args,
             beam_search_type=beam_search_type,
             start_epoch=start_epoch)
def main(args):
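    """Train a captioning model end-to-end on raw images.

    Builds an augmented train pipeline and a plain val pipeline, instantiates
    one of the supported decoders ('adaptive', 'gridtd', 'aoa'), and trains
    with Adam, optionally finetuning the CNN encoder at its own learning rate.
    """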
    print('The arguments are')
    print(args)
    print(f'model_type is {args.model_type}')
    word_map_path = f'./dataset/wordmap_{args.dataset}.json'
    with open(word_map_path, 'r') as fp:
        word_map = json.load(fp)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    train_transform = transforms.Compose([
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomResizedCrop(size=(args.height, args.width), scale=(args.scale_min, args.scale_max)),
        # transforms.RandomRotation((args.rotate_min, args.rotate_max)),
        transforms.Resize(size=(args.height, args.width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    val_transform = transforms.Compose([
        transforms.Resize(size=(args.height, args.width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    print('==========Loading Data==========')
    train_data = ImagecapDataset(
        args.dataset,
        'train',
        train_transform,
    )
    val_data = ImagecapDataset(
        args.dataset,
        'val',
        val_transform,
    )
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)
    print(f'{len(train_loader)} training batches')
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print(f'{len(val_loader)} validation batches')
    print('==========Data Loaded==========')
    print('==========Setting Model==========')
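    # Build the chosen model and split its parameters into groups: the CNN
    # encoder gets its own learning rate (args.encoder_lr), while the decoder
    # modules use the optimizer default (args.decoder_lr) set further below.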
    if args.model_type == 'adaptive':
        model = adaptiveattention.AdaptiveAttentionCaptioningModel(
            args.embed_dim, args.hidden_dim, len(word_map), args.encoder)
        img_encoder_params = [{
            'params': model.img_encoder.parameters(),
            'lr': args.encoder_lr
        }]
        decoder_parameters = [
            {'params': model.img_projector.parameters()},
            {'params': model.global_img_feature_proj.parameters()},
            {'params': model.AdaLSTM.parameters()},
            {'params': model.AdaAttention.parameters()},
            {'params': model.embedding.parameters()},
            {'params': model.fc.parameters()},
        ]
    elif args.model_type == 'gridtd':
        model = gridTDmodel.GridTDModel(args.embed_dim, args.hidden_dim,
                                        len(word_map), args.encoder)
        img_encoder_params = [{
            'params': model.img_encoder.parameters(),
            'lr': args.encoder_lr
        }]
        decoder_parameters = [
            {'params': model.img_projector.parameters()},
            {'params': model.global_img_feature_proj.parameters()},
            {'params': model.AdaLSTM.parameters()},
            {'params': model.LanguageLSTM.parameters()},
            {'params': model.AdaAttention.parameters()},
            {'params': model.embedding.parameters()},
            {'params': model.fc.parameters()},
        ]
    elif args.model_type == 'aoa':
        model = aoamodel.AOAModel(args.embed_dim, args.hidden_dim,
                                  args.num_head, len(word_map), args.encoder)
        img_encoder_params = [{
            'params': model.img_encoder.parameters(),
            'lr': args.encoder_lr
        }]
        decoder_parameters = [
            {'params': model.img_projector.parameters()},
            {'params': model.LanguageLSTM.parameters()},
            {'params': model.decoder_k_proj.parameters()},
            {'params': model.decoder_v_proj.parameters()},
            {'params': model.decoder_multihead_attention.parameters()},
            {'params': model.decoder_aoa_linear.parameters()},
            {'params': model.decoder_aoa_linear_gate.parameters()},
            {'params': model.embedding.parameters()},
            {'params': model.fc.parameters()},
        ]
    else:
        raise NotImplementedError(
            f'model_type {args.model_type} is not available yet')
    model.cuda()

    if args.resume:
        print(f'==========Resuming weights from {args.resume}==========')
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_cider = checkpoint['cider']
        model.load_state_dict(checkpoint['state_dict'])
    else:
        print('==========Initializing model from random==========')
        start_epoch = 0
        epochs_since_improvement = 0
        best_cider = 0
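    # Optimizer setup: with args.finetune_encoder the CNN trains at its own
    # learning rate alongside the decoder; otherwise the encoder is frozen
    # (requires_grad=False) and only decoder parameters are optimized.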
    if args.finetune_encoder:
        print('==========Training with finetuning CNN==========')
        optimizer = torch.optim.Adam(params=img_encoder_params +
                                     decoder_parameters,
                                     lr=args.decoder_lr,
                                     betas=(0.8, 0.999))
    else:
        print('==========Training with fixed CNN==========')
        for name, param in model.named_parameters():
            if 'img_encoder' in name:
                param.requires_grad = False
            if param.requires_grad:
                print(name, param.data.size())
        optimizer = torch.optim.Adam(params=decoder_parameters,
                                     lr=args.decoder_lr,
                                     betas=(0.8, 0.999))

    print('==========Start Training==========')
    for epoch in range(start_epoch, args.epochs):
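        # Plateau-style LR schedule: after two epochs without CIDEr
        # improvement, mutils.adjust_learning_rate decays the learning rate
        # (presumably a 0.8x multiplier floored at 2e-5).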
        # if args.model_type == 'aoa':
        #     if epoch > 0 and (epoch)%3==0:
        #         mutils.adjust_learning_rate(optimizer, 0.8, 2e-5)
        if epochs_since_improvement >= 2:
            mutils.adjust_learning_rate(optimizer, 0.8, 2e-5)
            epochs_since_improvement = 0
        if args.cider_tune:
            print('==========Training with CIDEr optimization==========')
            criterion = mutils.RewardCriterion().cuda()
            train_func = traincider
        elif args.lrp_tune:
            print('==========Training with LRP optimization==========')
            criterion = torch.nn.CrossEntropyLoss(
                ignore_index=word_map['<pad>']).cuda()
            train_func = train_lrp
        elif args.lrp_cider_tune:
            print('==========Training with LRP + CIDEr optimization==========')
            criterion = mutils.RewardCriterion().cuda()
            train_func = trainciderlrp
        else:
            print('==========Training==========')
            criterion = torch.nn.CrossEntropyLoss(
                ignore_index=word_map['<pad>']).cuda()
            train_func = train
            # args.ss_prob = args.ss_prob + (epoch // 10) * 0.03
            # print(f'Training with ss_prob {args.ss_prob}')
        train_func(train_loader, model, criterion, optimizer, epoch,
                   args.ss_prob, word_map, args.print_freq, args.grad_clip)

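        # Beam-search validation after every epoch; the positional argument 3
        # is presumably the beam size. CIDEr drives both the plateau schedule
        # and best-checkpoint selection below.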
        bleu, cider = validate(val_loader,
                               model,
                               word_map,
                               3,
                               epoch,
                               beam_search_type='beam_search')
        is_best = cider > best_cider
        best_cider = max(cider, best_cider)
        if not is_best:
            epochs_since_improvement += 1
            print(f'\nEpochs since last improvement: '
                  f'{epochs_since_improvement}\n')
        else:
            epochs_since_improvement = 0
        if args.lrp_tune:
            mutils.save_checkpoint(args.dataset,
                                   str(epoch) + 'lrp',
                                   epochs_since_improvement, model, optimizer,
                                   bleu, cider, is_best, args.save_path,
                                   args.encoder)
        else:
            mutils.save_checkpoint(args.dataset, epoch,
                                   epochs_since_improvement, model, optimizer,
                                   bleu, cider, is_best, args.save_path,
                                   args.encoder)
Example #3
def main(args):
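    """Train a captioning model on precomputed image features.

    Same training loop as the raw-image variant, but the datasets only
    convert cached features to tensors and a single Adam optimizer covers
    all parameters.
    """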
    print('The arguments are')
    print(args)
    print(f'model_type is {args.model_type}')
    word_map_path = f'./dataset/wordmap_{args.dataset}.json'
    with open(word_map_path, 'r') as fp:
        word_map = json.load(fp)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    print('==========Loading Data==========')
    train_data = ImagecapDatasetFromFeature(
        args.dataset,
        'train',
        train_transform,
    )
    val_data = ImagecapDatasetFromFeature(
        args.dataset,
        'val',
        val_transform,
    )
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)
    print(f'{len(train_loader)} training batches')

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print(f'{len(val_loader)} validation batches')

    print('==========Data Loaded==========')
    print('==========Setting Model==========')
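    # The 'BU' model variants consume the precomputed features loaded above
    # ('BU' presumably stands for bottom-up attention features).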
    if args.model_type == 'adaptive':
        model = adaptiveattention.AdaptiveAttentionCaptioningModel(
            args.embed_dim, args.hidden_dim, len(word_map), args.encoder)
    elif args.model_type == 'gridtd':
        model = gridTDmodel.GridTDModelBU(args.embed_dim, args.hidden_dim,
                                          len(word_map), args.encoder)
    elif args.model_type == 'aoa':
        model = aoamodel.AOAModelBU(args.embed_dim, args.hidden_dim,
                                    args.num_head, len(word_map), args.encoder)
    else:
        raise NotImplementedError(
            f'model_type {args.model_type} is not available yet')
    model.cuda()

    if args.resume:
        print(f'==========Resuming weights from {args.resume}==========')
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_cider = checkpoint['cider']
        model.load_state_dict(checkpoint['state_dict'])
    else:
        print('==========Initializing model from random==========')
        start_epoch = 0
        epochs_since_improvement = 0
        best_cider = 0
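    # One Adam optimizer over all parameters; unlike the raw-image variant,
    # no per-module parameter groups or separate encoder learning rate are
    # needed here.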
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=args.decoder_lr,
                                 betas=(0.8, 0.999))

    print('==========Start Training==========')
    for epoch in range(start_epoch, args.epochs):
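        # Same plateau schedule as the raw-image variant: decay the learning
        # rate after two epochs without CIDEr improvement.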
        # if args.model_type == 'aoa':
        #     if epoch > 0 and (epoch)%3==0:
        #         mutils.adjust_learning_rate(optimizer, 0.8, 2e-5)
        if epochs_since_improvement >= 2:
            mutils.adjust_learning_rate(optimizer, 0.8, 2e-5)
            epochs_since_improvement = 0
        if args.cider_tune:
            print('==========Training with CIDEr optimization==========')
            criterion = mutils.RewardCriterion().cuda()
            train_func = traincider
        elif args.lrp_tune:
            print('==========Training with LRP optimization==========')
            criterion = torch.nn.CrossEntropyLoss(
                ignore_index=word_map['<pad>']).cuda()
            train_func = train_lrp
        elif args.lrp_cider_tune:
            print('==========Training with LRP + CIDEr optimization==========')
            criterion = mutils.RewardCriterion().cuda()
            train_func = trainciderlrp
        else:
            print('==========Training==========')
            criterion = torch.nn.CrossEntropyLoss(
                ignore_index=word_map['<pad>']).cuda()
            train_func = train
            # args.ss_prob = args.ss_prob + (epoch // 10) * 0.03
            # print(f'Training with ss_prob {args.ss_prob}')
        train_func(train_loader, model, criterion, optimizer, epoch,
                   args.ss_prob, word_map, args.print_freq, args.grad_clip)

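        # Beam-search validation, as in the raw-image variant; CIDEr selects
        # the best checkpoint.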
        bleu, cider = validate(val_loader,
                               model,
                               word_map,
                               3,
                               epoch,
                               beam_search_type='beam_search')
        is_best = cider > best_cider
        best_cider = max(cider, best_cider)
        if not is_best:
            epochs_since_improvement += 1
            print(f'\nEpochs since last improvement: '
                  f'{epochs_since_improvement}\n')
        else:
            epochs_since_improvement = 0
        if args.lrp_tune:
            mutils.save_checkpoint(args.dataset,
                                   str(epoch) + 'lrp',
                                   epochs_since_improvement, model, optimizer,
                                   bleu, cider, is_best, args.save_path,
                                   args.encoder)
        else:
            mutils.save_checkpoint(args.dataset, epoch,
                                   epochs_since_improvement, model, optimizer,
                                   bleu, cider, is_best, args.save_path,
                                   args.encoder)
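
# ---------------------------------------------------------------------------
# A minimal, hypothetical driver for the feature-based main() in Example #3.
# The flag names mirror the attributes that main() actually reads
# (args.dataset, args.model_type, args.embed_dim, ...); every default value
# here is an illustrative assumption, not taken from the original project.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Train an image captioning model on precomputed features')
    parser.add_argument('--dataset', default='coco2017')
    parser.add_argument('--model_type', default='gridtd',
                        choices=['adaptive', 'gridtd', 'aoa'])
    parser.add_argument('--encoder', default='vgg16')
    parser.add_argument('--embed_dim', type=int, default=512)
    parser.add_argument('--hidden_dim', type=int, default=512)
    parser.add_argument('--num_head', type=int, default=8)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--decoder_lr', type=float, default=5e-4)
    parser.add_argument('--ss_prob', type=float, default=0.0)
    parser.add_argument('--print_freq', type=int, default=100)
    parser.add_argument('--grad_clip', type=float, default=0.1)
    parser.add_argument('--save_path', default='./output')
    parser.add_argument('--resume', default=None)
    parser.add_argument('--cider_tune', action='store_true')
    parser.add_argument('--lrp_tune', action='store_true')
    parser.add_argument('--lrp_cider_tune', action='store_true')

    main(parser.parse_args())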