def make_data(batch_size):
    print('Preparing data...', flush=True)
    if is_server():
        datadir = './.data/vision/imagenet'
    else:  # local settings
        datadir = '/fastwork/data/ilsvrc2012'

    # Set up the input pipeline.
    _, crop = bit_hyperrule.get_resolution_from_dataset('imagenet2012')
    input_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # valid_set = tv.datasets.ImageFolder(os.path.join(datadir, 'val'), input_tx)
    valid_set = tv.datasets.ImageNet(datadir, split='val', transform=input_tx)
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=batch_size, shuffle=False,
        num_workers=8, pin_memory=True, drop_last=False)
    return valid_set, valid_loader
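# For reference: bit_hyperrule.get_resolution_from_dataset is used throughout these
# snippets but not shown. A minimal sketch of its assumed behaviour, with the
# resolution table mirroring the public big_transfer repository (treat the exact
# values as an assumption, not a spec):

known_dataset_sizes = {
    'cifar10': (32, 32),
    'cifar100': (32, 32),
    'oxford_iiit_pet': (224, 224),
    'oxford_flowers102': (224, 224),
    'imagenet2012': (224, 224),
}


def get_resolution(original_resolution):
    """Returns (precrop, crop) for a given native image resolution."""
    area = original_resolution[0] * original_resolution[1]
    return (160, 128) if area < 96 * 96 else (512, 480)


def get_resolution_from_dataset(dataset):
    if dataset not in known_dataset_sizes:
        raise ValueError(f"Unsupported dataset {dataset}.")
    return get_resolution(known_dataset_sizes[dataset])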
def mkval(args):
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    valid_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    path = args.datadir
    validate_csv_file = pjoin(path, 'metadata', 'validate_labels.csv')
    valid_set = SnakeDataset(path, is_train=False, transform=valid_tx,
                             target_transform=None, csv_file=validate_csv_file)
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, drop_last=False)
    return valid_set, valid_loader, valid_set.classes
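# SnakeDataset is a project-specific class that is not included in these snippets.
# A minimal sketch of what such a CSV-backed dataset could look like; the column
# names ('hashed_id', 'class_id') and the .jpg layout are illustrative assumptions,
# not the project's actual schema:

import os

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class SnakeDataset(Dataset):
    """Hypothetical CSV-backed image dataset (illustrative only)."""

    def __init__(self, root, is_train, transform=None,
                 target_transform=None, csv_file=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        df = pd.read_csv(csv_file)
        # Assumed columns: 'hashed_id' (file name stem) and 'class_id' (label).
        self.samples = list(zip(df['hashed_id'], df['class_id']))
        self.classes = sorted(df['class_id'].unique().tolist())

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        name, label = self.samples[idx]
        img = Image.open(os.path.join(self.root, f'{name}.jpg')).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label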
def mktrainval():
    precrop, crop = bit_hyperrule.get_resolution_from_dataset("cifar10")
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_set = tv.datasets.ImageFolder(root=r'train', transform=train_tx)
    valid_set = tv.datasets.ImageFolder(root=r'test', transform=val_tx)

    # if args.examples_per_class is not None:
    #     indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
    #     train_set = torch.utils.data.Subset(train_set, indices=indices)

    batch_size = 600
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=4, shuffle=True,
        num_workers=2, pin_memory=True, drop_last=False)

    if batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True,
            num_workers=2, pin_memory=True, drop_last=False)
    else:
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, num_workers=2, pin_memory=True,
            sampler=torch.utils.data.RandomSampler(train_set, replacement=True,
                                                   num_samples=512))

    return train_set, valid_set, train_loader, valid_loader
def mkval(args):
    """Returns the validation dataset and loader."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset == "cifar10":
        valid_set = tv.datasets.CIFAR10(args.datadir, transform=val_tx,
                                        train=False, download=True)
    elif args.dataset == "cifar100":
        valid_set = tv.datasets.CIFAR100(args.datadir, transform=val_tx,
                                         train=False, download=True)
    elif args.dataset == "imagenet2012":
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    micro_batch_size = args.batch_size // args.batch_split

    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=micro_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=False)
    return valid_set, valid_loader
def main(args):
    tf.io.gfile.makedirs(args.logdir)
    logger = bit_common.setup_logger(args)
    logger.info(f'Available devices: {tf.config.list_physical_devices()}')

    tf.io.gfile.makedirs(args.bit_pretrained_dir)
    bit_model_file = os.path.join(args.bit_pretrained_dir, f'{args.model}.h5')
    if not tf.io.gfile.exists(bit_model_file):
        model_url = models.KNOWN_MODELS[args.model]
        logger.info(f'Downloading the model from {model_url}...')
        tf.io.gfile.copy(model_url, bit_model_file)

    # Set up the input pipeline.
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train',
                                                   args.examples_per_class)

    # Distribute training.
    strategy = tf.distribute.MirroredStrategy()
    num_devices = strategy.num_replicas_in_sync
    print('Number of devices: {}'.format(num_devices))

    resize_size, crop_size = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    data_train = input_pipeline.get_data(
        dataset=args.dataset, mode='train', repeats=None, batch_size=args.batch,
        resize_size=resize_size, crop_size=crop_size,
        examples_per_class=args.examples_per_class,
        examples_per_class_seed=args.examples_per_class_seed,
        mixup_alpha=bit_hyperrule.get_mixup(dataset_info['num_examples']),
        num_devices=num_devices, tfds_manual_dir=args.tfds_manual_dir)
    data_test = input_pipeline.get_data(
        dataset=args.dataset, mode='test', repeats=1, batch_size=args.batch,
        resize_size=resize_size, crop_size=crop_size,
        examples_per_class=1, examples_per_class_seed=0, mixup_alpha=None,
        num_devices=num_devices, tfds_manual_dir=args.tfds_manual_dir)

    data_train = data_train.map(lambda x: reshape_for_keras(
        x, batch_size=args.batch, crop_size=crop_size))
    data_test = data_test.map(lambda x: reshape_for_keras(
        x, batch_size=args.batch, crop_size=crop_size))

    with strategy.scope():
        filters_factor = int(args.model[-1]) * 4
        model = models.ResnetV2(num_units=models.NUM_UNITS[args.model],
                                num_outputs=21843,
                                filters_factor=filters_factor,
                                name="resnet",
                                trainable=True,
                                dtype=tf.float32)

        model.build((None, None, None, 3))
        logger.info('Loading weights...')
        model.load_weights(bit_model_file)
        logger.info('Weights loaded into model!')

        model._head = tf.keras.layers.Dense(units=dataset_info['num_classes'],
                                            use_bias=True,
                                            kernel_initializer="zeros",
                                            trainable=True,
                                            name="head/dense")

        lr_supports = bit_hyperrule.get_schedule(dataset_info['num_examples'])
        schedule_length = lr_supports[-1]
        # NOTE: Let's not do that unless verified necessary and we do the same
        # across all three codebases.
        # schedule_length = schedule_length * 512 / args.batch

        optimizer = tf.keras.optimizers.SGD(momentum=0.9)
        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

    logger.info('Fine-tuning the model...')
    steps_per_epoch = args.eval_every or schedule_length
    history = model.fit(
        data_train,
        steps_per_epoch=steps_per_epoch,
        epochs=schedule_length // steps_per_epoch,
        # The test data is only used here to evaluate our performance.
        validation_data=data_test,
        callbacks=[BiTLRSched(args.base_lr, dataset_info['num_examples'])],
    )

    for epoch, accu in enumerate(history.history['val_accuracy']):
        logger.info(f'Step: {epoch * args.eval_every}, '
                    f'Test accuracy: {accu:0.3f}')
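# reshape_for_keras is referenced above but not shown. A minimal sketch of what it
# is assumed to do, based on the public big_transfer TF2 code: flatten the batch
# into the shape Keras expects and return an (images, labels) pair for model.fit.

import tensorflow as tf


def reshape_for_keras(features, batch_size, crop_size):
    # Reshape the dict produced by the input pipeline into dense (x, y) tensors.
    images = tf.reshape(features['image'], (batch_size, crop_size, crop_size, 3))
    labels = tf.reshape(features['label'], (batch_size, -1))
    return images, labels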
def mktrainval(args, logger):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(args.datadir, transform=train_tx,
                                        train=True, download=True)
        valid_set = tv.datasets.CIFAR10(args.datadir, transform=val_tx,
                                        train=False, download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(args.datadir, transform=train_tx,
                                         train=True, download=True)
        valid_set = tv.datasets.CIFAR100(args.datadir, transform=val_tx,
                                         train=False, download=True)
    elif args.dataset == "imagenet2012":
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"), train_tx)
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    if args.examples_per_class is not None:
        logger.info(f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch // args.batch_split

    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=micro_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, drop_last=False)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size,
            num_workers=args.workers, pin_memory=True,
            sampler=torch.utils.data.RandomSampler(train_set, replacement=True,
                                                   num_samples=micro_batch_size))

    return train_set, valid_set, train_loader, valid_loader
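# A minimal sketch of how mktrainval above might be driven from an argparse
# namespace. The flag names are assumptions that simply mirror the fields the
# function reads (dataset, datadir, examples_per_class, batch, batch_split, workers);
# the real codebase defines its own argument parser.

import argparse
import logging

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", default="cifar10")
parser.add_argument("--datadir", default="./data")
parser.add_argument("--examples_per_class", type=int, default=None)
parser.add_argument("--batch", type=int, default=512)
parser.add_argument("--batch_split", type=int, default=1)
parser.add_argument("--workers", type=int, default=8)
args = parser.parse_args()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mktrainval-demo")

train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)
logger.info("Micro-batch size: %d", args.batch // args.batch_split)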
def run():
    aicrowd_helpers.execution_start()

    # MAGIC HAPPENS BELOW
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0")
    assert torch.cuda.is_available()

    precrop, crop = bit_hyperrule.get_resolution_from_dataset('snakes_dataset')  # verify
    valid_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    given_df = pd.read_csv(AICROWD_TEST_METADATA_PATH)
    valid_set = SnakeDataset(AICROWD_TEST_IMAGES_PATH, is_train=False,
                             transform=valid_tx, target_transform=None,
                             csv_file=AICROWD_TEST_METADATA_PATH)  # verify
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=32, shuffle=False,
        num_workers=0, pin_memory=True, drop_last=False)

    model = models.KNOWN_MODELS['BiT-M-R50x1'](
        head_size=len(VALID_SNAKE_SPECIES), zero_head=True)
    model = torch.nn.DataParallel(model)
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    model_loc = pjoin('models', 'initial.pth.tar')
    checkpoint = torch.load(model_loc, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    model.eval()

    results = np.empty((0, 783), float)
    # add name to dataset, y must be some random label
    for b, (x, y) in enumerate(valid_loader):
        with torch.no_grad():
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            logits = model(x)
            softmax_op = torch.nn.Softmax(dim=1)
            probs = softmax_op(logits)
            data_to_save = probs.data.cpu().numpy()
            results = np.concatenate((results, data_to_save), axis=0)

    filenames = given_df['hashed_id'].tolist()

    country_prob = pd.read_csv(
        pjoin('metadata', 'probability_of_species_per_country.csv'))
    country_name = country_prob[['Species/Country']]
    country_dict = {name[0]: i for i, name in enumerate(country_name.values)}

    given_country = given_df[['country']]
    country_list = []
    for country in given_country.values:
        country_list.append(str(country[0]).lower().replace(' ', '-'))  # has to be a better way

    adjusted_results = []
    for i, result in enumerate(results):
        probs = result
        assert len(probs) == 783
        try:
            country_now = country_list[i]
            country_location = country_dict[country_now]
            country_prob_per_this_country = country_prob.loc[[country_location]].values[0][1:]
            adjusted = country_prob_per_this_country * probs
            adjusted_results.append(adjusted)  # verify, we need list of list
        except KeyError:
            # Unknown country: fall back to the unadjusted probabilities.
            adjusted_results.append(probs)
    assert len(adjusted_results) == len(results)

    # Normalize each row so the adjusted probabilities sum to one.
    adjusted_results = np.array(adjusted_results)
    normalized_results = adjusted_results / adjusted_results.sum(axis=1)[:, None]

    df = pd.DataFrame(data=normalized_results, index=filenames,
                      columns=VALID_SNAKE_SPECIES)
    df.index.name = 'hashed_id'
    df.to_csv(AICROWD_PREDICTIONS_OUTPUT_PATH, index=True)

    aicrowd_helpers.execution_success(
        {"predictions_output_path": AICROWD_PREDICTIONS_OUTPUT_PATH})
def mktrainval(args, logger):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(args.datadir, transform=train_tx,
                                        train=True, download=True)
        valid_set = tv.datasets.CIFAR10(args.datadir, transform=val_tx,
                                        train=False, download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(args.datadir, transform=train_tx,
                                         train=True, download=True)
        valid_set = tv.datasets.CIFAR100(args.datadir, transform=val_tx,
                                         train=False, download=True)
    elif args.dataset == "imagenet2012":
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"), train_tx)
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    # TODO: Define custom dataloading logic here for custom datasets
    elif args.dataset == "logo_2k":
        train_set = GetLoader(data_root='logo2k/Logo-2K+',
                              data_list='logo2k/train.txt',
                              label_dict='logo2k/logo2k_labeldict.pkl',
                              transform=train_tx)
        valid_set = GetLoader(data_root='logo2k/Logo-2K+',
                              data_list='logo2k/test.txt',
                              label_dict='logo2k/logo2k_labeldict.pkl',
                              transform=val_tx)
    elif args.dataset == "targetlist":
        train_set = GetLoader(data_root='../../phishpedia/expand_targetlist',
                              data_list='../train_targets.txt',
                              label_dict='../target_dict.json',
                              transform=train_tx)
        valid_set = GetLoader(data_root='../../phishpedia/expand_targetlist',
                              data_list='../test_targets.txt',
                              label_dict='../target_dict.json',
                              transform=val_tx)

    logger.info("Using a training set with {} images.".format(len(train_set)))
    logger.info("Using a validation set with {} images.".format(len(valid_set)))
    logger.info("Num of classes: {}".format(len(valid_set.classes)))

    micro_batch_size = args.batch // args.batch_split

    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=micro_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, drop_last=False)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size,
            num_workers=args.workers, pin_memory=True,
            sampler=torch.utils.data.RandomSampler(train_set, replacement=True,
                                                   num_samples=micro_batch_size))

    return train_set, valid_set, train_loader, valid_loader
def get_data_loader(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    torch.backends.cudnn.benchmark = True

    if args.dataset == "imagenet":
        train_dir = os.path.join(args.data_path, 'train')
        val_dir = os.path.join(args.data_path, 'val')
        dataset, dataset_test, train_sampler, test_sampler = load_data(
            train_dir, val_dir, args.cache_dataset, args.distributed)
        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, sampler=train_sampler,
            num_workers=args.workers, pin_memory=True)
        data_loader_test = torch.utils.data.DataLoader(
            dataset_test, batch_size=args.batch_size, sampler=test_sampler,
            num_workers=args.workers, pin_memory=True)

    elif args.dataset == "cifar10":
        if args.model != "big_transfer":
            mean = [0.4914, 0.4822, 0.4465]
            std = [0.2023, 0.1994, 0.2010]
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
            dataset = CIFAR10(root=args.data_path, train=True, transform=transform_train)
            data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                     num_workers=4, shuffle=True,
                                     drop_last=True, pin_memory=True)

            transform_val = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
            dataset = CIFAR10(root=args.data_path, train=False, transform=transform_val)
            data_loader_test = DataLoader(dataset, batch_size=args.batch_size,
                                          num_workers=4, pin_memory=True)
        else:
            precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
            train_tx = transforms.Compose([
                transforms.Resize((precrop, precrop)),
                transforms.RandomCrop((crop, crop)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            val_tx = transforms.Compose([
                transforms.Resize((crop, crop)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            dataset = CIFAR10(root=args.data_path, train=True, transform=train_tx)
            data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                     num_workers=4, shuffle=True,
                                     drop_last=True, pin_memory=True)
            dataset = CIFAR10(root=args.data_path, train=False, transform=val_tx)
            data_loader_test = DataLoader(dataset, batch_size=args.batch_size,
                                          num_workers=4, pin_memory=True)

    return data_loader, data_loader_test
def mktrainval(args, logger):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.RandomRotation(90),
        tv.transforms.ColorJitter(),
        tv.transforms.RandomAffine(0, scale=(1.0, 2.0), shear=20),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    path = args.datadir
    train_csv_file = pjoin(path, 'metadata', 'train_labels.csv')
    validate_csv_file = pjoin(path, 'metadata', 'validate_labels.csv')
    train_set = SnakeDataset(path, is_train=True, transform=train_tx,
                             target_transform=None, csv_file=train_csv_file)
    valid_set = SnakeDataset(path, is_train=False, transform=val_tx,
                             target_transform=None, csv_file=validate_csv_file)

    if args.examples_per_class is not None:
        logger.info(f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch // args.batch_split

    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=micro_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, drop_last=False)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size,
            num_workers=args.workers, pin_memory=True,
            sampler=torch.utils.data.RandomSampler(train_set, replacement=True,
                                                   num_samples=micro_batch_size))

    return train_set, valid_set, train_loader, valid_loader
def main(args):
    logger = bit_common.setup_logger(args)
    logger.info(f'Available devices: {jax.devices()}')

    model = models.KNOWN_MODELS[args.model]

    # Load weights of a BiT model.
    bit_model_file = os.path.join(args.bit_pretrained_dir, f'{args.model}.npz')
    if not os.path.exists(bit_model_file):
        raise FileNotFoundError(
            f'Model file is not found in "{args.bit_pretrained_dir}" directory.')
    with open(bit_model_file, 'rb') as f:
        params_tf = np.load(f)
        params_tf = dict(zip(params_tf.keys(), params_tf.values()))

    resize_size, crop_size = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    # Set up the input pipeline.
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train',
                                                   args.examples_per_class)

    data_train = input_pipeline.get_data(
        dataset=args.dataset, mode='train', repeats=None, batch_size=args.batch,
        resize_size=resize_size, crop_size=crop_size,
        examples_per_class=args.examples_per_class,
        examples_per_class_seed=args.examples_per_class_seed,
        mixup_alpha=bit_hyperrule.get_mixup(dataset_info['num_examples']),
        num_devices=jax.local_device_count(),
        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_train)

    data_test = input_pipeline.get_data(
        dataset=args.dataset, mode='test', repeats=1, batch_size=args.batch_eval,
        resize_size=resize_size, crop_size=crop_size,
        examples_per_class=None, examples_per_class_seed=0, mixup_alpha=None,
        num_devices=jax.local_device_count(),
        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_test)

    # Build the ResNet architecture.
    ResNet = model.partial(num_classes=dataset_info['num_classes'])
    _, params = ResNet.init_by_shape(
        jax.random.PRNGKey(0), [([1, crop_size, crop_size, 3], jnp.float32)])

    resnet_fn = ResNet.call
    # pmap replicates the model over all GPUs.
    resnet_fn_repl = jax.pmap(ResNet.call)

    def cross_entropy_loss(*, logits, labels):
        logp = jax.nn.log_softmax(logits)
        return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
        logits = resnet_fn(params, images)
        return cross_entropy_loss(logits=logits, labels=labels)

    # Update step, replicated over all GPUs.
    @partial(jax.pmap, axis_name='batch')
    def update_fn(opt, lr, batch):
        l, g = jax.value_and_grad(loss_fn)(opt.target, batch['image'], batch['label'])
        g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
        opt = opt.apply_gradient(g, learning_rate=lr)
        return opt

    # In-place update of the randomly initialized weights with the BiT weights.
    tf2jax.transform_params(params, params_tf,
                            num_classes=dataset_info['num_classes'])

    # Create the optimizer and replicate it over all GPUs.
    opt = optim.Momentum(beta=0.9).create(params)
    opt_repl = flax_utils.replicate(opt)

    # Delete references to the objects that are not needed anymore.
    del opt
    del params

    total_steps = bit_hyperrule.get_schedule(dataset_info['num_examples'])[-1]

    # Run the training loop.
    for step, batch in zip(range(1, total_steps + 1),
                           data_train.as_numpy_iterator()):
        lr = bit_hyperrule.get_lr(step - 1, dataset_info['num_examples'],
                                  args.base_lr)
        opt_repl = update_fn(opt_repl, flax_utils.replicate(lr), batch)

        # Run an eval step.
        if ((args.eval_every and step % args.eval_every == 0)
                or (step == total_steps)):
            accuracy_test = np.mean([
                c for batch in data_test.as_numpy_iterator()
                for c in (np.argmax(resnet_fn_repl(opt_repl.target, batch['image']),
                                    axis=2) ==
                          np.argmax(batch['label'], axis=2)).ravel()
            ])
            logger.info(f'Step: {step}, '
                        f'learning rate: {lr:.07f}, '
                        f'Test accuracy: {accuracy_test:0.3f}')
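# The step budget and learning rate above come from bit_hyperrule.get_schedule and
# bit_hyperrule.get_lr. A sketch of their assumed behaviour; the concrete step
# counts mirror the public big_transfer code and should be treated as an assumption.

def get_schedule(dataset_size):
    # Step budgets grow with dataset size; the last entry is the total step count.
    if dataset_size < 20_000:
        return [100, 200, 300, 400, 500]
    elif dataset_size < 500_000:
        return [500, 3000, 6000, 9000, 10_000]
    else:
        return [500, 6000, 12_000, 18_000, 20_000]


def get_lr(step, dataset_size, base_lr=0.003):
    """Linear warm-up, then a staircase decay by 10x at each support point."""
    supports = get_schedule(dataset_size)
    if step < supports[0]:       # warm-up
        return base_lr * step / supports[0]
    elif step > supports[-1]:    # end of training
        return None
    for s in supports[1:]:
        if s < step:
            base_lr /= 10
    return base_lr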
def select_worst_images(args, model, full_train_loader, device):
    print("Selecting images for next epoch training...")
    model.eval()
    gts = []
    paths = []
    losses = []
    micro_batch_size = args.batch_size // args.batch_split
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    if args.input_channels == 3:
        train_tx = tv.transforms.Compose([
            tv.transforms.Resize((precrop, precrop)),
            tv.transforms.RandomCrop((crop, crop)),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    elif args.input_channels == 2:
        train_tx = tv.transforms.Compose([
            tv.transforms.Resize((precrop, precrop)),
            tv.transforms.RandomCrop((crop, crop)),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5), (0.5, 0.5)),
        ])
    elif args.input_channels == 1:
        train_tx = tv.transforms.Compose([
            tv.transforms.Resize((precrop, precrop)),
            tv.transforms.RandomCrop((crop, crop)),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5,), (0.5,)),
        ])

    pbar = enumerate(full_train_loader)
    pbar = tqdm.tqdm(pbar, total=len(full_train_loader))
    for b, (path, x, y) in pbar:
        with torch.no_grad():
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Compute output, measure accuracy and record loss.
            logits = model(x)
            paths.extend(path)
            gts.extend(y.cpu().numpy())
            c = torch.nn.CrossEntropyLoss(reduction='none')(logits, y)
            losses.extend(c.cpu().numpy().tolist())  # Also ensures a sync point.

    gts = np.array(gts)
    losses = np.array(losses)
    # Zero out the losses of the assumed-noisy fraction of highest-loss samples.
    losses[np.argsort(losses)[int(losses.shape[0] * (1.0 - args.noise)):]] = 0.0

    # paths_ = np.array(paths)[np.where(losses > np.median(losses))[0]]
    # gts_ = gts[np.where(losses > np.median(losses))[0]]
    selection_idx = int(args.data_fraction * losses.shape[0])
    paths_ = np.array(paths)[np.argsort(losses)[-selection_idx:]]
    gts_ = gts[np.argsort(losses)[-selection_idx:]]

    smart_train_set = ImageFolder(paths_, gts_, train_tx, crop)
    smart_train_loader = torch.utils.data.DataLoader(
        smart_train_set, batch_size=micro_batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    return smart_train_set, smart_train_loader
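# The ImageFolder used above is not torchvision's class: it is constructed from
# explicit path and label arrays. A minimal sketch of such a wrapper, with the
# constructor signature inferred from the call site ImageFolder(paths_, gts_,
# train_tx, crop); the class body itself is an assumption, not the project's code.

from PIL import Image
from torch.utils.data import Dataset


class ImageFolder(Dataset):
    """Hypothetical path-list dataset matching ImageFolder(paths, labels, transform, crop)."""

    def __init__(self, paths, labels, transform=None, crop=None):
        self.paths = list(paths)
        self.labels = list(labels)
        self.transform = transform
        self.crop = crop  # kept for parity with the call site; unused in this sketch

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # Return the path as well, mirroring the (path, x, y) batches consumed above.
        return self.paths[idx], img, int(self.labels[idx])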
def _mktrainval(args, logger):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    if args.test_run:  # save memory
        precrop, crop = 64, 56

    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    collate_fn = None
    n_train = None
    micro_batch_size = args.batch // args.batch_split

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(args.datadir, transform=train_tx,
                                        train=True, download=True)
        valid_set = tv.datasets.CIFAR10(args.datadir, transform=val_tx,
                                        train=False, download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(args.datadir, transform=train_tx,
                                         train=True, download=True)
        valid_set = tv.datasets.CIFAR100(args.datadir, transform=val_tx,
                                         train=False, download=True)
    elif args.dataset == "imagenet2012":
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"),
                                            transform=train_tx)
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"),
                                            transform=val_tx)
    elif args.dataset.startswith('objectnet') or args.dataset.startswith('imageneta'):
        # objectnet, objectnet_bbox, objectnet_no_bbox (and the imageneta variants).
        identifier = 'objectnet' if args.dataset.startswith('objectnet') else 'imageneta'
        valid_set = tv.datasets.ImageFolder(f"../datasets/{identifier}/", transform=val_tx)

        if args.inpaint == 'none':
            if args.dataset == 'objectnet' or args.dataset == 'imageneta':
                train_set = tv.datasets.ImageFolder(
                    pjoin(args.datadir, f"train_{args.dataset}"), transform=train_tx)
            else:
                # Keep only the images with (or without) a bounding box.
                train_bbox_file = '../datasets/imagenet/LOC_train_solution_size.csv'
                df = pd.read_csv(train_bbox_file)
                filenames = set(df[df.bbox_ratio <= args.bbox_max_ratio].ImageId)
                if args.dataset == f"{identifier}_no_bbox":
                    is_valid_file = lambda path: os.path.basename(path).split('.')[0] not in filenames
                elif args.dataset == f"{identifier}_bbox":
                    is_valid_file = lambda path: os.path.basename(path).split('.')[0] in filenames
                else:
                    raise NotImplementedError()

                train_set = tv.datasets.ImageFolder(
                    pjoin(args.datadir, f"train_{identifier}"),
                    is_valid_file=is_valid_file, transform=train_tx)
        else:  # do inpainting
            train_tx = tv.transforms.Compose([
                data_utils.Resize((precrop, precrop)),
                data_utils.RandomCrop((crop, crop)),
                data_utils.RandomHorizontalFlip(),
                data_utils.ToTensor(),
                data_utils.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            train_set = ImagenetBoundingBoxFolder(
                root=f"../datasets/imagenet/train_{identifier}",
                bbox_file='../datasets/imagenet/LOC_train_solution.csv',
                transform=train_tx)
            collate_fn = bbox_collate
            n_train = len(train_set) * 2
            micro_batch_size //= 2
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    if args.examples_per_class is not None:
        logger.info(f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=micro_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, drop_last=False,
            collate_fn=collate_fn)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size,
            num_workers=args.workers, pin_memory=True,
            sampler=torch.utils.data.RandomSampler(train_set, replacement=True,
                                                   num_samples=micro_batch_size),
            collate_fn=collate_fn)

    if n_train is None:
        n_train = len(train_set)

    return n_train, len(valid_set.classes), train_loader, valid_loader