Code Example #1
File: analyze.py Project: nikopj/DGCN
import os

import numpy as np
import torch
from tqdm import tqdm

import data   # project-local module (nikopj/DGCN)
import utils  # project-local module (nikopj/DGCN)


def test(args, model, noise_std=25, device=torch.device('cpu')):
    loader = data.getDataLoaders(**args['train']['loaders'])['test']
    model.eval()
    t = tqdm(iter(loader), desc="TEST", dynamic_ncols=True)
    psnr = 0
    for itern, batch in enumerate(t):
        batch = batch.to(device)
        # corrupt the clean batch with additive white Gaussian noise
        noisy_batch = utils.awgn(batch, noise_std)
        with torch.no_grad():
            output = model(noisy_batch)
        mse = torch.mean((batch - output)**2).item()
        # PSNR = -10 * log10(MSE) for images scaled to [0, 1]
        psnr = psnr - 10 * np.log10(mse)
    # average the accumulated per-batch PSNR
    psnr = psnr / (itern + 1)
    print(f"Test PSNR = {psnr:.2f} dB")
    with open(os.path.join(args['paths']['save'], 'test.psnr'),
              'a') as psnr_file:
        psnr_file.write(f'{psnr}  ')
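
The helper utils.awgn above is project code that is not shown on this page. For reference, a minimal sketch of such an additive-white-Gaussian-noise helper, assuming images scaled to [0, 1] and a noise_std given on the 0-255 intensity scale (both are assumptions, not confirmed by the source):

import torch


def awgn(x, noise_std):
    # Add white Gaussian noise; noise_std is on the 0-255 scale while
    # x is assumed to lie in [0, 1].
    return x + (noise_std / 255.0) * torch.randn_like(x)
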
Code Example #2
import os
from time import gmtime, strftime

import torch

# set_args, getDataLoaders, and train are project-local helpers
# defined in the surrounding module (not shown on this page).


def main():
    args = set_args()
    setattr(args, 'model_time', strftime('%H:%M:%S', gmtime()))
    setattr(args, 'class_size', 4)

    # loading EmoContext data
    print("loading data")
    train_dataloader, valid_dataloader, num_train_examples = getDataLoaders(
        args)

    best_model, max_dev_f1 = train(args, train_dataloader, valid_dataloader,
                                   num_train_examples)

    if not os.path.exists('saved_models'):
        os.makedirs('saved_models')
    torch.save(best_model,
               f'saved_models/BERT_{args.model_time}_{max_dev_f1}.pt')

    print('training finished!')
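
The getDataLoaders(args) call above returns the train loader, the validation loader, and the number of training examples. Its implementation is not shown here; a minimal sketch of what such a function might look like, assuming the EmoContext text has already been tokenized into tensors (train_ids, train_labels, etc. are illustrative parameters that the real function would instead load itself from args, which is assumed to carry batch_size):

from torch.utils.data import DataLoader, TensorDataset


def getDataLoaders(args, train_ids, train_labels, valid_ids, valid_labels):
    # Wrap pre-tokenized input IDs and labels; a real implementation
    # would read and tokenize the raw EmoContext text first.
    train_ds = TensorDataset(train_ids, train_labels)
    valid_ds = TensorDataset(valid_ids, valid_labels)
    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    valid_dl = DataLoader(valid_ds, batch_size=args.batch_size)
    return train_dl, valid_dl, len(train_ds)
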
Code Example #3
    def test_getDataLoaders(self):
        train_ds, val_ds = getDatasets(
            'images/labels.csv', 'images/12_images', 'fname', 'is_mono',
            train_size=1.0, train_transform=transforms.ToTensor()
        )

        train_dl, val_dl = getDataLoaders(
            train_ds, val_ds, batch_size=1, num_workers=2,
            use_weighted_sampling=True
        )

        # Verify that weighted sampling is working: draw 100 full passes
        # over the 12-image dataset and tally how often each class appears.
        counts = [0, 0]

        for i in range(100):
            for samp in train_dl:
                label = int(samp[1][0])
                counts[label] += 1

        # Normalize by the total number of draws (100 passes x 12 images).
        counts[0] /= 100 * 12
        counts[1] /= 100 * 12
        print(counts)

        # The two frequencies sum to 1, so bounding class 0 within 0.03 of
        # 0.5 also bounds class 1.
        self.assertLess(abs(0.5 - counts[0]), 0.03)
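
The near-50/50 class frequencies this test asserts are what torch.utils.data.WeightedRandomSampler produces when samples are weighted by inverse class frequency. A minimal sketch of the sampling logic a getDataLoaders implementation with use_weighted_sampling=True might use (makeWeightedLoader and its parameters are illustrative):

from collections import Counter

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler


def makeWeightedLoader(dataset, labels, batch_size=1, num_workers=2):
    # Weight each sample by the inverse frequency of its class so every
    # class is drawn with (approximately) equal probability.
    freq = Counter(labels)
    weights = torch.tensor([1.0 / freq[y] for y in labels],
                           dtype=torch.double)
    sampler = WeightedRandomSampler(weights, num_samples=len(labels),
                                    replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      num_workers=num_workers)
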
Code Example #4
File: single_split_train.py Project: stuckyb/inat_cv
    # (The snippet opens inside a loop that writes each parsed argument
    # to a log file.)
    fout.write('{0}: {1}\n'.format(key, argsdict[key]))

# A negative n_gpus value means "use every GPU visible to torch".
n_gpus = args.n_gpus
if n_gpus < 0:
    n_gpus = torch.cuda.device_count()

train_data, val_data = getDatasets(args.labels_csv,
                                   args.images,
                                   args.fnames_col,
                                   args.labels_col,
                                   train_size=args.train_split,
                                   train_transform=train_transform,
                                   val_transform=val_transform,
                                   rng=rng)
trainloader, valloader = getDataLoaders(train_data,
                                        val_data,
                                        batch_size=args.batch_size,
                                        use_weighted_sampling=True)

# Resume from checkpointed weights when a path was given; otherwise
# build a fresh model.
if args.model_wts != '':
    print(f'Loading model weights from {args.model_wts}...')
    model = ENModel.load_from_checkpoint(args.model_wts,
                                         lr=args.learning_rate,
                                         n_classes=len(
                                             train_data.dataset.classes))
else:
    model = ENModel(args.learning_rate, len(train_data.dataset.classes))

if args.top_only:
    # Presumably freezes everything except the top (classifier) layers.
    model.setTrainTopOnly()

tb_logger = pl_loggers.TensorBoardLogger(outputdir, exp_name)
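
getDatasets is also project code; the train_size-based split it performs can be sketched with a reproducible index permutation and torch.utils.data.Subset. This is consistent with the train_data.dataset.classes access above, since Subset exposes the wrapped dataset as .dataset (splitDataset is an illustrative name, and rng is assumed to be a numpy.random.Generator):

import numpy as np
from torch.utils.data import Subset


def splitDataset(dataset, train_size, rng):
    # Reproducibly shuffle all indices, then cut at the requested fraction.
    idx = rng.permutation(len(dataset)).tolist()
    n_train = int(round(train_size * len(dataset)))
    return Subset(dataset, idx[:n_train]), Subset(dataset, idx[n_train:])


# e.g. train_ds, val_ds = splitDataset(ds, 0.8, np.random.default_rng(0))
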
Code Example #5
File: cv_train.py Project: stuckyb/inat_cv
    #print(len(train_idx), len(valid_idx))
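    # train_idx / valid_idx come from the enclosing cross-validation loop
    # (not shown in this snippet); index pairs like these are commonly
    # produced with, e.g., sklearn.model_selection.KFold(...).split(...).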

    fold_folder = outpath / ('fold_' + str(loop_count))
    fold_folder.mkdir()
    print(
        f'\n#\n# Cross-validation fold {loop_count}; saving results to '
        f'{fold_folder}.\n#'
    )
    
    train_data, val_data = getDatasets(
        args.labels_csv, args.images, args.fnames_col, args.labels_col,
        train_idx=train_idx, valid_idx=valid_idx,
        train_transform=train_transform, val_transform=val_transform, rng=rng
    )
    trainloader, valloader = getDataLoaders(
        train_data, val_data, batch_size=args.batch_size
    )

    if args.model_wts != '':
        print(f'Loading model weights from {args.model_wts}...')
        model = ENModel.load_from_checkpoint(
            args.model_wts, lr=args.learning_rate,
            n_classes=len(train_data.dataset.classes)
        )
    else:
        model = ENModel(args.learning_rate, len(train_data.dataset.classes))

    if args.top_only:
        model.setTrainTopOnly()

    tb_logger = pl_loggers.TensorBoardLogger(