def get_dataset_stats(args, num_channels=3):
    """Compute per-channel pixel mean and std over the training split.

    The caller's ``args`` is deep-copied so ``modify_args`` cannot mutate it.
    Statistics are computed per batch and then averaged across batches, so
    the returned std is an approximation of the true dataset std (mean of
    batch stds). That matches the original behavior and is kept as-is.

    Args:
        args: experiment argument namespace (batch_size, num_workers, cuda,
            transformer specs, ...).
        num_channels: number of image channels to compute statistics for.
            Defaults to 3 (RGB), preserving the original hard-coded behavior.

    Returns:
        (means, stds): two lists of length ``num_channels`` containing
        scalar tensors, one mean and one std per channel.
    """
    args = copy.deepcopy(args)
    modify_args(args)
    transformers = transformer_factory.get_transformers(
        args.image_transformers, args.tensor_transformers, args)
    train, _, _ = dataset_factory.get_dataset(args, transformers, [])
    data_loader = DataLoader(
        train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=False,
        collate_fn=ignore_None_collate)

    means = {c: [] for c in range(num_channels)}
    stds = {c: [] for c in range(num_channels)}
    for batch in tqdm(data_loader):
        tensor = batch['x']
        if args.cuda:
            tensor = tensor.cuda()
        for channel in range(num_channels):
            tensor_chan = tensor[:, channel]
            means[channel].append(torch.mean(tensor_chan))
            stds[channel].append(torch.std(tensor_chan))

    # BUGFIX: torch.Tensor(list_of_0dim_tensors) breaks when the elements are
    # CUDA tensors (which they are whenever args.cuda is set above).
    # torch.stack keeps device and dtype intact and reduces correctly.
    means = [torch.stack(means[c]).mean() for c in range(num_channels)]
    stds = [torch.stack(stds[c]).mean() for c in range(num_channels)]
    return means, stds
def get_image_to_right_side(args):
    """Return a dict mapping each image path (across train/dev/test) to a
    bool: True when the right half of the image has a larger total pixel
    sum than the left half, i.e. the content is right-aligned.

    ``args`` is deep-copied first so ``modify_args`` does not mutate the
    caller's namespace.
    """
    args = copy.deepcopy(args)
    modify_args(args)
    transformers = transformer_factory.get_transformers(
        args.image_transformers, args.tensor_transformers, args)
    splits = dataset_factory.get_dataset(args, transformers, transformers)

    image_to_side = {}
    for split in splits:
        loader = DataLoader(
            split,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            drop_last=False,
            pin_memory=True,
            collate_fn=ignore_None_collate)
        for batch in tqdm(loader):
            images = batch['x']
            paths = batch['path']
            if args.cuda:
                images = images.cuda()
            batch_size, _, _, width = images.size()
            half = width // 2
            # Flatten each half so a single sum per image compares total
            # intensity on each side.
            left = images[..., :half].contiguous().view(batch_size, -1)
            right = images[..., half:].contiguous().view(batch_size, -1)
            right_heavier = right.sum(dim=-1) > left.sum(dim=-1)
            for i, path in enumerate(paths):
                image_to_side[path] = bool(right_heavier[i])
    return image_to_side
commit.hexsha, commit.message, commit.author, commit.committed_date)) if args.get_dataset_stats: print("\nComputing image mean and std...") args.img_mean, args.img_std = get_dataset_stats(args) print('Mean: {}'.format(args.img_mean)) print('Std: {}'.format(args.img_std)) print("\nLoading data-augmentation scheme...") transformers = transformer_factory.get_transformers( args.image_transformers, args.tensor_transformers, args) test_transformers = transformer_factory.get_transformers( args.test_image_transformers, args.test_tensor_transformers, args) # Load dataset and add dataset specific information to args print("\nLoading data...") train_data, dev_data, test_data = dataset_factory.get_dataset(args, transformers, test_transformers) # Load model and add model specific information to args if args.snapshot is None: model = model_factory.get_model(args) else: model = model_factory.load_model(args.snapshot, args) if args.replace_snapshot_pool: non_trained_model = model_factory.get_model(args) model._model.pool = non_trained_model._model.pool model._model.args = non_trained_model._model.args print(model) # Load run parameters if resuming that run. args.model_path = state.get_model_path(args) print('Trained model will be saved to [%s]' % args.model_path)