示例#1
0
def main(args):
    """
    Training and validation driver.

    Loads the word map, builds the RNN encoder and attention decoder,
    optionally resumes both from checkpoints, then alternates one training
    and one validation pass per epoch, appending both losses to a log file
    and saving a checkpoint every third epoch.
    """
    with open(args.word_map_file, 'rb') as f:
        word_map = pickle.load(f)

    # Which encoder to use: the RNN-based Encoder. (A Transformer variant,
    # EncoderT, exists in the project but is not used here.)
    encoder = Encoder(input_feature_dim=args.input_feature_dim,
                      encoder_hidden_dim=args.encoder_hidden_dim,
                      encoder_layer=args.encoder_layer,
                      rnn_unit=args.rnn_unit,
                      use_gpu=args.CUDA,
                      dropout_rate=args.dropout_rate)

    decoder = DecoderWithAttention(attention_dim=args.attention_dim,
                                   embed_dim=args.emb_dim,
                                   decoder_dim=args.decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=args.dropout)

    if args.resume:
        encoder.load_state_dict(torch.load(args.encoder_path))
        decoder.load_state_dict(torch.load(args.decoder_path))

    # Move the models (and the loss) to the GPU *before* building the
    # optimizers, so optimizer state is allocated for the device-resident
    # parameters. Also merges the two duplicated `if args.CUDA` blocks.
    if args.CUDA:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = nn.CrossEntropyLoss()

    # Only parameters that require gradients are optimized.
    encoder_parameter = [p for p in encoder.parameters() if p.requires_grad]
    decoder_parameter = [p for p in decoder.parameters() if p.requires_grad]

    # NOTE(review): the original used args.decoder_lr for BOTH optimizers;
    # kept as-is for compatibility. If a separate args.encoder_lr exists,
    # it belongs on the encoder optimizer.
    encoder_optimizer = torch.optim.Adam(encoder_parameter, lr=args.decoder_lr)
    decoder_optimizer = torch.optim.Adam(decoder_parameter, lr=args.decoder_lr)

    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(args.data_path, split='TRAIN'),
        batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
        collate_fn=pad_collate_train, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(args.data_path, split='VAL'),
        batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
        collate_fn=pad_collate_train, pin_memory=True)

    # Epochs
    best_bleu4 = 0
    for epoch in range(args.start_epoch, args.epochs):

        # One epoch of training; `losst` is the training loss.
        losst = train(train_loader=train_loader,
                      encoder=encoder,
                      decoder=decoder,
                      criterion=criterion,
                      encoder_optimizer=encoder_optimizer,
                      decoder_optimizer=decoder_optimizer,
                      epoch=epoch, args=args)

        # One epoch of validation. (The original guarded this with
        # `epoch % 1 == 0`, which is always true — run unconditionally so
        # `lossv` below is always bound.)
        lossv = validate(val_loader=val_loader,
                         encoder=encoder,
                         decoder=decoder,
                         criterion=criterion,
                         best_bleu=best_bleu4,
                         args=args)

        info = 'LOSST - {losst:.4f}, LOSSv - {lossv:.4f}\n'.format(
                losst=losst,
                lossv=lossv)

        # Append both losses to the log file.
        # NOTE(review): `dev` is not defined in this function — presumably a
        # module-level log path; confirm it exists at import time.
        with open(dev, "a") as f:
            f.write(info)
            f.write("\n")

        # BLEU-based model selection is disabled; instead save a checkpoint
        # every third epoch.
        if epoch % 3 == 0:
            save_checkpoint(epoch, encoder, decoder, encoder_optimizer,
                            decoder_optimizer, lossv)
示例#2
0
# The encoder is only optimized when it is being fine-tuned; otherwise no
# encoder optimizer is created at all.
encoder = Encoder()
encoder.fine_tune(fine_tune_encoder)
encoder_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, encoder.parameters()),
    lr=encoder_lr) if fine_tune_encoder else None

# Cosine-annealing LR schedules with period T_max=5 and floor lr 1e-5.
# NOTE(review): encoder_sched is constructed unconditionally, but
# encoder_optimizer is None when fine_tune_encoder is False — that call
# would fail; confirm fine_tune_encoder is True for this configuration.
decoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(decoder_optimizer,
                                                           5,
                                                           eta_min=1e-5,
                                                           last_epoch=-1)
encoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(encoder_optimizer,
                                                           5,
                                                           eta_min=1e-5,
                                                           last_epoch=-1)
encoder = encoder.cuda()
decoder = decoder.cuda()

# 150k training samples at batch size 64; 10k validation samples at 50.
train_loader = generate_data_loader(train_root, 64, int(150000))
val_loader = generate_data_loader(val_root, 50, int(10000))
criterion = nn.CrossEntropyLoss().to(device)

class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.

    NOTE(review): only ``val`` is reset in the visible code, although the
    summary promises avg/sum/count too — the rest of the class appears
    truncated in this excerpt; confirm against the full source.
    """
    def __init__(self):
        # All tracked statistics start at zero.
        self.reset()

    def reset(self):
        # Most recent observed value.
        self.val = 0
示例#3
0
def train(args):
    """
    Train (or evaluate) the image-caption encoder/decoder from a config file.

    Loads the config, seeds all RNGs, builds train/val datasets and loaders
    for the selected dataset ('bird' or 'coco'), constructs or loads the
    encoder/decoder pair, then either runs evaluation only (``args.eval``)
    or trains for ``cfg.IMAGETEXT.EPOCH`` epochs with per-epoch validation,
    step-decayed learning rates, and periodic checkpointing.

    Raises:
        NotImplementedError: if ``args.dataset`` is neither 'bird' nor 'coco'.
    """
    cfg_from_file(args.cfg)
    cfg.WORKERS = args.num_workers
    pprint.pprint(cfg)
    # Seed every RNG (NumPy, CPU torch, all CUDA devices) for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # Separate outputers for train/val logging and checkpointing.
    outputer_train = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                              cfg.IMAGETEXT.SAVE_EVERY)
    outputer_val = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                            cfg.IMAGETEXT.SAVE_EVERY)
    # define the dataset
    split_dir, bshuffle = 'train', True

    # Base image size scaled by the number of tree branches.
    imsize = cfg.TREE.BASE_SIZE * (2**(cfg.TREE.BRANCH_NUM - 1))
    # transforms.Scale is deprecated (removed in torchvision >= 0.8);
    # Resize is the drop-in replacement with identical behavior. Images are
    # upscaled by 76/64 then cropped back to `imsize` (random crop for
    # training augmentation, center crop for deterministic validation).
    train_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.CenterCrop(imsize),
    ])
    if args.dataset == 'bird':
        train_dataset = ImageTextDataset(args.data_dir,
                                         split_dir,
                                         transform=train_transform,
                                         sample_type='train')
        val_dataset = ImageTextDataset(args.data_dir,
                                       'val',
                                       transform=val_transform,
                                       sample_type='val')
    elif args.dataset == 'coco':
        train_dataset = CaptionDataset(args.data_dir,
                                       split_dir,
                                       transform=train_transform,
                                       sample_type='train',
                                       coco_data_json=args.coco_data_json)
        val_dataset = CaptionDataset(args.data_dir,
                                     'val',
                                     transform=val_transform,
                                     sample_type='val',
                                     coco_data_json=args.coco_data_json)
    else:
        raise NotImplementedError

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=bshuffle,
        num_workers=int(cfg.WORKERS))
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=False,
        num_workers=1)
    # Model construction: either a raw checkpoint provides both modules, or
    # they are built fresh and optionally initialized from a named checkpoint.
    if args.raw_checkpoint != '':
        encoder, decoder = load_raw_checkpoint(args.raw_checkpoint)
    else:
        encoder = Encoder()
        decoder = DecoderWithAttention(
            attention_dim=cfg.IMAGETEXT.ATTENTION_DIM,
            embed_dim=cfg.IMAGETEXT.EMBED_DIM,
            decoder_dim=cfg.IMAGETEXT.DECODER_DIM,
            vocab_size=train_dataset.n_words)
        if cfg.IMAGETEXT.CHECKPOINT != '':
            outputer_val.log("load model from: {}".format(
                cfg.IMAGETEXT.CHECKPOINT))
            encoder, decoder = load_checkpoint(encoder, decoder,
                                               cfg.IMAGETEXT.CHECKPOINT)

    # Encoder is frozen: only the decoder learns (see optimizer note below).
    encoder.fine_tune(False)
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    loss_func = torch.nn.CrossEntropyLoss()
    if args.eval:  # eval only: requires a checkpoint, runs one validation pass
        outputer_val.log("only eval the model...")
        assert cfg.IMAGETEXT.CHECKPOINT != ''
        val_rtn_dict, outputer_val = validate_one_epoch(
            0, val_dataloader, encoder, decoder, loss_func, outputer_val)
        outputer_val.log("\n[valid]: {}\n".format(dict2str(val_rtn_dict)))
        return

    # Optimizers and step-decay schedules (decay every 10 epochs).
    # NOTE(review): the encoder optimizer is built over all encoder params
    # even though fine_tune(False) freezes them — harmless but confirm.
    optimizer_encoder = torch.optim.Adam(encoder.parameters(),
                                         lr=cfg.IMAGETEXT.ENCODER_LR)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(),
                                         lr=cfg.IMAGETEXT.DECODER_LR)
    encoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_encoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_decoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    print("train the model...")
    for epoch_idx in range(cfg.IMAGETEXT.EPOCH):
        train_rtn_dict, outputer_train = train_one_epoch(
            epoch_idx, train_dataloader, encoder, decoder, optimizer_encoder,
            optimizer_decoder, loss_func, outputer_train)
        # Step the LR schedules once per epoch, after training.
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()

        outputer_train.log("\n[train] epoch: {}, {}\n".format(
            epoch_idx, dict2str(train_rtn_dict)))
        val_rtn_dict, outputer_val = validate_one_epoch(
            epoch_idx, val_dataloader, encoder, decoder, loss_func,
            outputer_val)
        outputer_val.log("\n[valid] epoch: {}, {}\n".format(
            epoch_idx, dict2str(val_rtn_dict)))

        # Periodic checkpoint (cadence governed by the outputer's SAVE_EVERY).
        outputer_val.save_step({
            "encoder": encoder.state_dict(),
            "decoder": decoder.state_dict()
        })
    # Final checkpoint after the last epoch.
    outputer_val.save({
        "encoder": encoder.state_dict(),
        "decoder": decoder.state_dict()
    })