def main(args):
    """Run training and periodic validation for the caption model.

    Loads the word map, builds the RNN encoder and attention decoder,
    optionally resumes from saved state dicts, then alternates one
    training epoch with one validation pass, appending both losses to a
    log file and saving a checkpoint every third epoch.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide the model/optimizer hyper-parameters, data paths,
        `CUDA` flag, `resume` flag and epoch range used below.
    """
    # NOTE(review): pickle.load is unsafe on untrusted input; acceptable
    # here only because the word map is a locally produced artifact.
    with open(args.word_map_file, 'rb') as f:
        word_map = pickle.load(f)

    encoder = Encoder(input_feature_dim=args.input_feature_dim,
                      encoder_hidden_dim=args.encoder_hidden_dim,
                      encoder_layer=args.encoder_layer,
                      rnn_unit=args.rnn_unit,
                      use_gpu=args.CUDA,
                      dropout_rate=args.dropout_rate)
    decoder = DecoderWithAttention(attention_dim=args.attention_dim,
                                   embed_dim=args.emb_dim,
                                   decoder_dim=args.decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=args.dropout)

    # Optionally resume from previously saved weights.
    if args.resume:
        encoder.load_state_dict(torch.load(args.encoder_path))
        decoder.load_state_dict(torch.load(args.decoder_path))

    # Move models (and the loss) to the GPU *before* building the
    # optimizers so everything sees the final parameter placement.
    if args.CUDA:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = nn.CrossEntropyLoss()

    # Only parameters that require gradients are optimized.
    encoder_parameter = [p for p in encoder.parameters() if p.requires_grad]
    decoder_parameter = [p for p in decoder.parameters() if p.requires_grad]
    # BUG FIX: the encoder optimizer previously used args.decoder_lr.
    # Prefer a dedicated encoder learning rate when the namespace has
    # one, falling back to the old behavior so existing configs work.
    encoder_optimizer = torch.optim.Adam(
        encoder_parameter, lr=getattr(args, 'encoder_lr', args.decoder_lr))
    decoder_optimizer = torch.optim.Adam(decoder_parameter,
                                         lr=args.decoder_lr)

    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(args.data_path, split='TRAIN'),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, collate_fn=pad_collate_train,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(args.data_path, split='VAL'),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, collate_fn=pad_collate_train,
        pin_memory=True)

    # NOTE(review): best_bleu4 is never updated below; it is kept only
    # to satisfy validate()'s signature.
    best_bleu4 = 0
    for epoch in range(args.start_epoch, args.epochs):
        losst = train(train_loader=train_loader,
                      encoder=encoder,
                      decoder=decoder,
                      criterion=criterion,
                      encoder_optimizer=encoder_optimizer,
                      decoder_optimizer=decoder_optimizer,
                      epoch=epoch,
                      args=args)

        # Validate every epoch; the modulus is kept so the validation
        # cadence is easy to change later.
        if epoch % 1 == 0:
            lossv = validate(val_loader=val_loader,
                             encoder=encoder,
                             decoder=decoder,
                             criterion=criterion,
                             best_bleu=best_bleu4,
                             args=args)
            info = 'LOSST - {losst:.4f}, LOSSv - {lossv:.4f}\n'.format(
                losst=losst, lossv=lossv)
            # NOTE(review): `dev` is not defined in this chunk —
            # presumably a module-level log-file path; confirm it
            # exists at import time.
            with open(dev, "a") as f:
                f.write(info)
                f.write("\n")

        # Checkpoint every third epoch. BLEU-based model selection was
        # disabled by the original author; keep the same cadence.
        if epoch % 3 == 0:
            save_checkpoint(epoch, encoder, decoder, encoder_optimizer,
                            decoder_optimizer, lossv)
# NOTE(review): the following top-level statements look like a fragment of a
# different training script; they reference names (fine_tune_encoder,
# encoder_lr, decoder_optimizer, train_root, val_root, device,
# generate_data_loader, ...) that are not defined anywhere in this chunk —
# confirm they exist at module level before running this file.
encoder = Encoder()
encoder.fine_tune(fine_tune_encoder)
# Optimize only trainable encoder parameters; no optimizer at all when the
# encoder is frozen (fine_tune_encoder falsy).
encoder_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, encoder.parameters()),
    lr=encoder_lr) if fine_tune_encoder else None
# Cosine-annealed learning rate with a 5-step period, floored at 1e-5.
decoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    decoder_optimizer, 5, eta_min=1e-5, last_epoch=-1)
# NOTE(review): encoder_optimizer may be None here (when fine_tune_encoder
# is falsy), which would make this scheduler construction raise — verify.
encoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    encoder_optimizer, 5, eta_min=1e-5, last_epoch=-1)
encoder = encoder.cuda()
decoder = decoder.cuda()
# Hard-coded batch sizes / sample counts: 64x150k for train, 50x10k for val.
train_loader = generate_data_loader(train_root, 64, int(150000))
val_loader = generate_data_loader(val_root, 50, int(10000))
criterion = nn.CrossEntropyLoss().to(device)


class AverageMeter(object):
    """ Keeps track of most recent, average, sum, and count of a metric. """

    def __init__(self):
        # All state lives in reset() so a meter can be reused across epochs.
        self.reset()

    def reset(self):
        # NOTE(review): only `val` is reset here although the docstring
        # promises average/sum/count — the class appears truncated in this
        # chunk; the remaining fields/methods are presumably elsewhere.
        self.val = 0
def train(args):
    """Train (or, with --eval, only evaluate) the encoder/decoder pair.

    Reads the experiment config from args.cfg, seeds all RNGs, builds
    datasets/loaders for the 'bird' or 'coco' dataset, optionally loads
    a raw or regular checkpoint, then runs the epoch loop: train one
    epoch, step both LR schedulers, validate, and save state dicts.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide cfg, num_workers, seed, output_dir, dataset,
        data_dir, raw_checkpoint, eval and (for coco) coco_data_json.
    """
    cfg_from_file(args.cfg)
    cfg.WORKERS = args.num_workers
    pprint.pprint(cfg)
    # set the seed manually — seeds numpy, CPU torch and every CUDA device
    # for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # define outputer: separate train/val loggers sharing the same
    # output dir and print/save cadence from the config.
    outputer_train = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                              cfg.IMAGETEXT.SAVE_EVERY)
    outputer_val = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                            cfg.IMAGETEXT.SAVE_EVERY)
    # define the dataset
    split_dir, bshuffle = 'train', True
    # Get data loader. Image size grows by 2x per extra branch of the tree.
    imsize = cfg.TREE.BASE_SIZE * (2**(cfg.TREE.BRANCH_NUM - 1))
    # NOTE(review): transforms.Scale is deprecated (removed in newer
    # torchvision in favor of Resize) — this file pins an old torchvision.
    # Images are upscaled by 76/64 then cropped back to imsize (random crop
    # for train, center crop for val).
    train_transform = transforms.Compose([
        transforms.Scale(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
    ])
    val_transform = transforms.Compose([
        transforms.Scale(int(imsize * 76 / 64)),
        transforms.CenterCrop(imsize),
    ])
    if args.dataset == 'bird':
        train_dataset = ImageTextDataset(args.data_dir, split_dir,
                                         transform=train_transform,
                                         sample_type='train')
        val_dataset = ImageTextDataset(args.data_dir, 'val',
                                       transform=val_transform,
                                       sample_type='val')
    elif args.dataset == 'coco':
        train_dataset = CaptionDataset(args.data_dir, split_dir,
                                       transform=train_transform,
                                       sample_type='train',
                                       coco_data_json=args.coco_data_json)
        val_dataset = CaptionDataset(args.data_dir, 'val',
                                     transform=val_transform,
                                     sample_type='val',
                                     coco_data_json=args.coco_data_json)
    else:
        raise NotImplementedError
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=bshuffle,
        num_workers=int(cfg.WORKERS))
    # Validation uses a single worker and no shuffling.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=False,
        num_workers=1)
    # define the model and optimizer
    if args.raw_checkpoint != '':
        encoder, decoder = load_raw_checkpoint(args.raw_checkpoint)
    else:
        encoder = Encoder()
        decoder = DecoderWithAttention(
            attention_dim=cfg.IMAGETEXT.ATTENTION_DIM,
            embed_dim=cfg.IMAGETEXT.EMBED_DIM,
            decoder_dim=cfg.IMAGETEXT.DECODER_DIM,
            vocab_size=train_dataset.n_words)
    # load checkpoint
    # NOTE(review): if both raw_checkpoint and cfg.IMAGETEXT.CHECKPOINT are
    # set, this second load overwrites the raw checkpoint weights — confirm
    # that precedence is intended.
    if cfg.IMAGETEXT.CHECKPOINT != '':
        outputer_val.log("load model from: {}".format(
            cfg.IMAGETEXT.CHECKPOINT))
        encoder, decoder = load_checkpoint(encoder, decoder,
                                           cfg.IMAGETEXT.CHECKPOINT)
    # Encoder backbone is always frozen here; only the decoder trains fully.
    encoder.fine_tune(False)
    # to cuda
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    loss_func = torch.nn.CrossEntropyLoss()
    if args.eval:
        # eval only: run a single validation pass and return early.
        outputer_val.log("only eval the model...")
        # NOTE(review): assert is stripped under python -O; a checkpoint is
        # required for eval-only mode.
        assert cfg.IMAGETEXT.CHECKPOINT != ''
        val_rtn_dict, outputer_val = validate_one_epoch(
            0, val_dataloader, encoder, decoder, loss_func, outputer_val)
        outputer_val.log("\n[valid]: {}\n".format(dict2str(val_rtn_dict)))
        return
    # define optimizer — separate Adam per module with its own LR, each
    # decayed by LR_GAMMA every 10 epochs.
    optimizer_encoder = torch.optim.Adam(encoder.parameters(),
                                         lr=cfg.IMAGETEXT.ENCODER_LR)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(),
                                         lr=cfg.IMAGETEXT.DECODER_LR)
    encoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_encoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_decoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    print("train the model...")
    for epoch_idx in range(cfg.IMAGETEXT.EPOCH):
        # val_rtn_dict, outputer_val = validate_one_epoch(epoch_idx, val_dataloader, encoder,
        #                                                 decoder, loss_func, outputer_val)
        # outputer_val.log("\n[valid] epoch: {}, {}".format(epoch_idx, dict2str(val_rtn_dict)))
        train_rtn_dict, outputer_train = train_one_epoch(
            epoch_idx, train_dataloader, encoder, decoder,
            optimizer_encoder, optimizer_decoder, loss_func,
            outputer_train)
        # adjust lr scheduler once per epoch, after the training pass.
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        outputer_train.log("\n[train] epoch: {}, {}\n".format(
            epoch_idx, dict2str(train_rtn_dict)))
        val_rtn_dict, outputer_val = validate_one_epoch(
            epoch_idx, val_dataloader, encoder, decoder, loss_func,
            outputer_val)
        outputer_val.log("\n[valid] epoch: {}, {}\n".format(
            epoch_idx, dict2str(val_rtn_dict)))
        # Per-epoch snapshot of both state dicts.
        outputer_val.save_step({
            "encoder": encoder.state_dict(),
            "decoder": decoder.state_dict()
        })
    # Final save after the last epoch.
    # NOTE(review): the source formatting is ambiguous about whether this
    # call sits inside or after the epoch loop; placed after, matching the
    # save_step/save naming — confirm against the original file.
    outputer_val.save({
        "encoder": encoder.state_dict(),
        "decoder": decoder.state_dict()
    })