Example #1
        args.trained_models_fn + '_' + args.geometric_model + '_grid_loss' +
        args.feature_extraction_cnn + '.pth.tar')

best_test_loss = float("inf")

print('Starting training...')

for epoch in range(1, args.num_epochs + 1):
    train_loss = train(epoch,
                       model,
                       loss,
                       optimizer,
                       dataloader,
                       pair_generation_tnf,
                       log_interval=100)
    test_loss = test(model, loss, dataloader_test, pair_generation_tnf)

    # remember best loss
    is_best = test_loss < best_test_loss
    best_test_loss = min(test_loss, best_test_loss)
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'args': args,
            'state_dict': model.state_dict(),
            'best_test_loss': best_test_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint_name)

print('Done!')
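The snippets above call a save_checkpoint helper that is not shown. A minimal sketch of the three-argument form used in Example #1, assuming the usual torch.save-plus-copy pattern; the 'best_' filename prefix is an illustrative choice, not taken from the original:

import os
import shutil
import torch

def save_checkpoint(state, is_best, file):
    # Persist the full training state (epoch, args, model and optimizer state dicts)
    torch.save(state, file)
    if is_best:
        # Keep a separate copy of the best checkpoint so far (prefix is illustrative)
        model_dir = os.path.dirname(file)
        model_fn = os.path.basename(file)
        shutil.copyfile(file, os.path.join(model_dir, 'best_' + model_fn))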
Example #2
        # Load optimizer state dict
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Load epoch information
        start_epoch = checkpoint['epoch']
        print("Reloading from--[%s]" % args.load_model)

    for epoch in range(start_epoch, args.num_epochs + 1):
        # Call train, test function
        train_loss = train(epoch,
                           model,
                           loss,
                           optimizer,
                           dataloader_train,
                           use_cuda,
                           log_interval=100)
        test_acc = test(model, dataloader_test, len(dataset_test), use_cuda)

        checkpoint_name = os.path.join(
            args.trained_models_dir,
            args.model_type + '_epoch_' + str(epoch) + '.pth.tar')

        # Save checkpoint
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, checkpoint_name)

    print('Done!')
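Example #2 evaluates with a test helper that returns classification accuracy rather than a loss. A minimal sketch, assuming a standard classifier whose dataloader yields (input, target) batches; the name and signature mirror the call above, but the body is illustrative:

import torch

def test(model, dataloader_test, n_samples, use_cuda):
    # Evaluate the model and return the fraction of correctly classified samples
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in dataloader_test:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
    return correct / n_samples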
Example #3
def start_train(training_path, test_image_path, load_from, out_path, vis_env, paper_affine_generator=False,
                random_seed=666, log_interval=100, multi_gpu=True, use_cuda=True):

    init_seeds(random_seed + random.randint(0, 10000))

    device, local_rank = torch_util.select_device(multi_process=multi_gpu, apex=mixed_precision)

    # args.batch_size = args.batch_size * torch.cuda.device_count()
    args.batch_size = 16
    args.lr_scheduler = True
    draw_test_loss = False
    print(args.batch_size)


    print("创建模型中")
    model = CNNRegistration(use_cuda=use_cuda)

    model = model.to(device)

    # Optimizer and scheduler
    optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr)

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=args.lr_max_iter,
                                                               eta_min=1e-7)
    else:
        scheduler = False

    print("加载权重")
    minium_loss,saved_epoch= load_checkpoint(model,optimizer,load_from,0)

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    if multi_gpu:
        model = nn.DataParallel(model)

    loss = NTGLoss()
    pair_generator = RandomTnsPair(use_cuda=use_cuda)
    gridGen = AffineGridGen()
    vis = VisdomHelper(env_name=vis_env)

    print("创建dataloader")
    RandomTnsDataset = RandomTnsData(training_path, cache_images=False,paper_affine_generator = paper_affine_generator,
                                     transform=NormalizeImageDict(["image"]))
    train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    if draw_test_loss:
        testDataset = RandomTnsData(test_image_path, cache_images=False, paper_affine_generator=paper_affine_generator,
                                    transform=NormalizeImageDict(["image"]))
        test_dataloader = DataLoader(testDataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                     pin_memory=False)

    print('Starting training...')

    for epoch in range(saved_epoch, args.num_epochs):
        start_time = time.time()

        train_loss = train(epoch, model, loss, optimizer, train_dataloader, pair_generator, gridGen, vis,
                           use_cuda=use_cuda, log_interval=log_interval, scheduler=scheduler)

        if draw_test_loss:
            test_loss = test(model, loss, test_dataloader, pair_generator, gridGen, use_cuda=use_cuda)
            vis.drawBothLoss(epoch, train_loss, test_loss, 'loss_table')
        else:
            vis.drawLoss(epoch, train_loss)

        end_time = time.time()
        print("epoch:", str(end_time - start_time),'秒')

        is_best = train_loss < minium_loss
        minium_loss = min(train_loss, minium_loss)

        state_dict = model.module.state_dict() if multi_gpu else model.state_dict()
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            'state_dict': state_dict,
            'minium_loss': minium_loss,
            'model_loss': train_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, out_path)
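Example #3 resumes training through a load_checkpoint helper that returns the lowest loss seen so far and the epoch to restart from. A minimal sketch, assuming the checkpoint was written by the save_checkpoint call at the end of the loop (the keys 'state_dict', 'optimizer', 'minium_loss' and 'epoch' come from that call; the fallback return values are illustrative):

import os
import torch

def load_checkpoint(model, optimizer, load_from, default_epoch):
    # Restore model/optimizer state and return (best loss so far, epoch to resume from)
    if not load_from or not os.path.isfile(load_from):
        # Nothing to resume from: start fresh with an "infinite" best loss
        return float('inf'), default_epoch
    checkpoint = torch.load(load_from, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['minium_loss'], checkpoint['epoch']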
Example #4
epochArray = np.zeros(args.num_epochs)
trainLossArray = np.zeros(args.num_epochs)
testLossArray = np.zeros(args.num_epochs)

for epoch in range(1, args.num_epochs + 1):
    train_loss = train(epoch,
                       model,
                       loss,
                       optimizer,
                       dataloader,
                       pair_generation_tnf,
                       log_interval=10)
    test_loss = test(model,
                     loss,
                     dataloader_test,
                     pair_generation_tnf,
                     use_cuda=use_cuda,
                     geometric_model=args.geometric_model)

    scheduler.step()

    epochArray[epoch - 1] = epoch
    trainLossArray[epoch - 1] = train_loss
    testLossArray[epoch - 1] = test_loss

    # remember best loss
    is_best = test_loss < best_test_loss
    best_test_loss = min(test_loss, best_test_loss)
    save_checkpoint(
        {
            'epoch': epoch + 1,