def on_train_batch_begin(self, batch, logs=None):
    lr = bit_hyperrule.get_lr(self.step, self.num_samples, self.base_lr)
    tf.keras.backend.set_value(self.model.optimizer.lr, lr)
    self.step += 1
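# Every training loop in this file drives its learning rate through
# bit_hyperrule.get_lr, which also signals the end of training by returning
# None. The sketch below shows one plausible shape of that schedule (linear
# warmup, staircase decay, None once the last support step is reached); the
# support points and dataset-size thresholds are illustrative assumptions,
# not values taken from this document.
def get_schedule(dataset_size):
    # Smaller datasets get shorter schedules (illustrative thresholds).
    if dataset_size < 20_000:
        return [100, 200, 300, 400, 500]
    elif dataset_size < 500_000:
        return [500, 3000, 6000, 9000, 10_000]
    else:
        return [500, 6000, 12_000, 18_000, 20_000]


def get_lr(step, dataset_size, base_lr=0.003):
    """Return the learning rate for `step`, or None once training should stop."""
    supports = get_schedule(dataset_size)
    if step < supports[0]:           # linear warmup
        return base_lr * step / supports[0]
    elif step >= supports[-1]:       # schedule exhausted
        return None
    else:                            # divide by 10 at each support point passed
        for s in supports[1:]:
            if s < step:
                base_lr /= 10
        return base_lr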
def main(args):
    logger = bit_common.setup_logger(args)

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Going to train on {device}")

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)

    logger.info(f"Loading model from {args.model}.npz")
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    model.load_from(np.load(f"{args.model}.npz"))

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    step = 0

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info(f"Model will be saved in '{savename}'")
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(f"Found saved model to resume from at '{savename}'")
        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info(f"Resumed at step {step}")
    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    model = model.to(device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    logger.info("Starting training!")
    chrono = lb.Chrono()
    accum_steps = 0
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            with chrono.measure("fprop"):
                logits = model(x)
                if mixup > 0.0:
                    c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
                else:
                    c = cri(logits, y)
                c_num = float(c.data.cpu().numpy())  # Also ensures a sync point.

            # Accumulate grads
            with chrono.measure("grads"):
                (c / args.batch_split).backward()
                accum_steps += 1

            accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
            logger.info(f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
            logger.flush()

            # Update params
            if accum_steps == args.batch_split:
                with chrono.measure("update"):
                    optim.step()
                    optim.zero_grad()
                step += 1
                accum_steps = 0
                # Sample new mixup ratio for next batch
                mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

                # Run evaluation and save the model.
                if args.eval_every and step % args.eval_every == 0:
                    run_eval(model, valid_loader, device, chrono, logger, step)
                    if args.save:
                        torch.save({
                            "step": step,
                            "model": model.state_dict(),
                            "optim": optim.state_dict(),
                        }, savename)

            end = time.time()

    # Final eval at end of training.
    run_eval(model, valid_loader, device, chrono, logger, step='end')

    logger.info(f"Timings:\n{chrono}")
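# The loop above iterates recycle(train_loader) forever and, when mixup is
# active, calls mixup_data / mixup_criterion. None of these helpers are defined
# in this file; the sketches below show one plausible implementation and should
# be read as assumptions rather than the authoritative versions.
import torch


def recycle(iterable):
    # Re-iterate a DataLoader endlessly without caching its items
    # (unlike itertools.cycle, which would keep every batch in memory).
    while True:
        for item in iterable:
            yield item


def mixup_data(x, y, l):
    # Mix every example with a randomly chosen partner from the same batch.
    indices = torch.randperm(x.shape[0], device=x.device)
    mixed_x = l * x + (1 - l) * x[indices]
    return mixed_x, y, y[indices]


def mixup_criterion(criterion, pred, y_a, y_b, l):
    # The loss is the same convex combination applied to the two label sets.
    return l * criterion(pred, y_a) + (1 - l) * criterion(pred, y_b)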
def main(args):
    logger = bit_common.setup_logger(args)

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info("Going to train on {}".format(device))

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)

    logger.info("Loading model from {}.npz".format(args.model))
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    model.load_from(np.load("{}.npz".format(args.model)))

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    step = 0

    # If pretrained weights are specified
    if args.weights_path:
        logger.info("Loading weights from {}".format(args.weights_path))
        checkpoint = torch.load(args.weights_path, map_location="cpu")
        # New task might have different classes; remove the pretrained classifier weights
        del checkpoint['model']['module.head.conv.weight']
        del checkpoint['model']['module.head.conv.bias']
        model.load_state_dict(checkpoint["model"], strict=False)

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info("Model will be saved in '{}'".format(savename))
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info("Found saved model to resume from at '{}'".format(savename))
        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info("Resumed at step {}".format(step))
    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    model = model.to(device)
    # Send to GPU
    optimizer_to(optim, device)

    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    logger.info("Starting training!")
    chrono = lb.Chrono()
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            logits = model(x)
            if mixup > 0.0:
                c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
            else:
                c = cri(logits, y)
            c_num = float(c.data.cpu().numpy())  # Also ensures a sync point.

            # Accumulate grads
            (c / args.batch_split).backward()

            logger.info("[step {}]: loss={:.5f} (lr={:.1e})".format(step, c_num, lr))
            logger.flush()

            # Update params
            optim.step()
            optim.zero_grad()
            step += 1

            # Sample new mixup ratio for next batch
            mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

            end = time.time()

            if step % 50 == 0:
                torch.save({
                    "step": step,
                    "model": model.state_dict(),
                    "optim": optim.state_dict(),
                }, savename)

    # Final eval at end of training.
    run_eval(model, valid_loader, device, chrono, logger, step='end')
    logger.info("Timings:\n{}".format(chrono))
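# The variant above calls optimizer_to(optim, device) so that optimizer state
# restored from a CPU checkpoint (e.g. SGD momentum buffers) ends up on the same
# device as the model before the first optim.step(). The helper is not defined
# in this file; a minimal sketch, written as an assumption:
import torch


def optimizer_to(optim, device):
    # Move every tensor held in the optimizer's state onto `device`.
    for state in optim.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)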
def main(args):
    logger = bit_common.setup_logger(args)
    logger.info(f'Available devices: {jax.devices()}')

    model = models.KNOWN_MODELS[args.model]

    # Load weights of a BiT model
    bit_model_file = os.path.join(args.bit_pretrained_dir, f'{args.model}.npz')
    if not os.path.exists(bit_model_file):
        raise FileNotFoundError(
            f'Model file is not found in "{args.bit_pretrained_dir}" directory.')
    with open(bit_model_file, 'rb') as f:
        params_tf = np.load(f)
        params_tf = dict(zip(params_tf.keys(), params_tf.values()))

    resize_size, crop_size = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    # Setup input pipeline
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train',
                                                   args.examples_per_class)

    data_train = input_pipeline.get_data(
        dataset=args.dataset,
        mode='train',
        repeats=None,
        batch_size=args.batch,
        resize_size=resize_size,
        crop_size=crop_size,
        examples_per_class=args.examples_per_class,
        examples_per_class_seed=args.examples_per_class_seed,
        mixup_alpha=bit_hyperrule.get_mixup(dataset_info['num_examples']),
        num_devices=jax.local_device_count(),
        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_train)

    data_test = input_pipeline.get_data(
        dataset=args.dataset,
        mode='test',
        repeats=1,
        batch_size=args.batch_eval,
        resize_size=resize_size,
        crop_size=crop_size,
        examples_per_class=None,
        examples_per_class_seed=0,
        mixup_alpha=None,
        num_devices=jax.local_device_count(),
        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_test)

    # Build ResNet architecture
    ResNet = model.partial(num_classes=dataset_info['num_classes'])
    _, params = ResNet.init_by_shape(
        jax.random.PRNGKey(0),
        [([1, crop_size, crop_size, 3], jnp.float32)])

    resnet_fn = ResNet.call

    # pmap replicates the model over all GPUs
    resnet_fn_repl = jax.pmap(ResNet.call)

    def cross_entropy_loss(*, logits, labels):
        logp = jax.nn.log_softmax(logits)
        return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
        logits = resnet_fn(params, images)
        return cross_entropy_loss(logits=logits, labels=labels)

    # Update step, replicated over all GPUs
    @partial(jax.pmap, axis_name='batch')
    def update_fn(opt, lr, batch):
        l, g = jax.value_and_grad(loss_fn)(opt.target, batch['image'],
                                           batch['label'])
        g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
        opt = opt.apply_gradient(g, learning_rate=lr)
        return opt

    # In-place update of the randomly initialized weights with the BiT weights
    tf2jax.transform_params(params, params_tf,
                            num_classes=dataset_info['num_classes'])

    # Create the optimizer and replicate it over all GPUs
    opt = optim.Momentum(beta=0.9).create(params)
    opt_repl = flax_utils.replicate(opt)

    # Delete references to the objects that are not needed anymore
    del opt
    del params

    total_steps = bit_hyperrule.get_schedule(dataset_info['num_examples'])[-1]

    # Run training loop
    for step, batch in zip(range(1, total_steps + 1),
                           data_train.as_numpy_iterator()):
        lr = bit_hyperrule.get_lr(step - 1, dataset_info['num_examples'],
                                  args.base_lr)
        opt_repl = update_fn(opt_repl, flax_utils.replicate(lr), batch)

        # Run eval step
        if ((args.eval_every and step % args.eval_every == 0) or
                (step == total_steps)):
            accuracy_test = np.mean([
                c for batch in data_test.as_numpy_iterator()
                for c in (np.argmax(resnet_fn_repl(opt_repl.target, batch['image']),
                                    axis=2) ==
                          np.argmax(batch['label'], axis=2)).ravel()
            ])
            logger.info(f'Step: {step}, '
                        f'learning rate: {lr:.07f}, '
                        f'Test accuracy: {accuracy_test:0.3f}')
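# In the eval step above, resnet_fn_repl is a pmapped function, so its inputs
# and outputs carry a leading device axis: images are shaped
# [num_devices, per_device_batch, H, W, 3] and logits come back as
# [num_devices, per_device_batch, num_classes], which is why correctness is
# computed with argmax over axis=2. The dummy-array illustration below mirrors
# that accuracy computation; all shapes are assumptions for demonstration only.
import numpy as np

num_devices, per_device_batch, num_classes = 2, 4, 10
logits = np.random.randn(num_devices, per_device_batch, num_classes)
labels = np.eye(num_classes)[
    np.random.randint(num_classes, size=(num_devices, per_device_batch))]
correct = (np.argmax(logits, axis=2) == np.argmax(labels, axis=2)).ravel()
accuracy = correct.mean()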
def main(args):
    logger = bit_common.setup_logger(args)

    if args.test_run:
        args.batch = 8
        args.batch_split = 1
        args.workers = 1

    logger.info("Args: " + str(args))

    # Fix seed
    # torch.manual_seed(args.seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # np.random.seed(args.seed)
    # random.seed(args.seed)

    # Speed up
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Going to train on {device}")

    n_train, n_classes, train_loader, valid_loader = mktrainval(args, logger)

    if args.inpaint != 'none':
        if args.inpaint == 'mean':
            inpaint_model = (lambda x, mask: x * mask)
        elif args.inpaint == 'random':
            inpaint_model = RandomColorWithNoiseInpainter((0.5, 0.5, 0.5),
                                                          (0.5, 0.5, 0.5))
        elif args.inpaint == 'local':
            inpaint_model = LocalMeanInpainter(window=)
        elif args.inpaint == 'cagan':
            inpaint_model = CAInpainter(
                valid_loader.batch_size,
                checkpoint_dir='./inpainting_models/release_imagenet_256/')
        else:
            raise NotImplementedError(f"Unknown inpaint {args.inpaint}")

    logger.info(f"Training {args.model}")
    if args.model in models.KNOWN_MODELS:
        model = models.KNOWN_MODELS[args.model](head_size=n_classes,
                                                zero_head=True)
    else:  # from torchvision
        model = getattr(torchvision.models, args.model)(pretrained=args.finetune)

    # Resume fine-tuning if we find a saved model.
    step = 0

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    savename = pjoin(args.logdir, args.name, "model.pt")
    try:
        logger.info(f"Model will be saved in '{savename}'")
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(f"Found saved model to resume from at '{savename}'")

        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        model = model.to(device)

        # Note: no weight-decay!
        optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
        optim.load_state_dict(checkpoint["optim"])
        logger.info(f"Resumed at step {step}")
    except FileNotFoundError:
        if args.finetune:
            logger.info("Fine-tuning from BiT")
            model.load_from(np.load(f"models/{args.model}.npz"))

        model = model.to(device)
        optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    if args.fp16:
        model, optim = amp.initialize(model, optim, opt_level="O1")

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    optim.zero_grad()

    model.train()

    mixup = 0
    if args.mixup:
        mixup = bit_hyperrule.get_mixup(n_train)

    cri = torch.nn.CrossEntropyLoss().to(device)

    def counterfact_cri(logit, y):
        # Plain cross-entropy when no counterfactual (negative) labels are present.
        if torch.all(y >= 0):
            return F.cross_entropy(logit, y, reduction='mean')

        loss1 = F.cross_entropy(logit[y >= 0], y[y >= 0], reduction='sum')

        cf_logit, cf_y = logit[y < 0], -(y[y < 0] + 1)

        # Implement my own logsumexp trick
        m, _ = torch.max(cf_logit, dim=1, keepdim=True)
        exp_logit = torch.exp(cf_logit - m)
        sum_exp_logit = torch.sum(exp_logit, dim=1)

        eps = 1e-20
        num = (sum_exp_logit - exp_logit[torch.arange(exp_logit.shape[0]), cf_y])
        num = torch.log(num + eps)
        denon = torch.log(sum_exp_logit + eps)

        # Negative log probability
        loss2 = -(num - denon).sum()
        return (loss1 + loss2) / y.shape[0]

    logger.info("Starting training!")
    chrono = lb.Chrono()
    accum_steps = 0
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Handle inpainting
            if not isinstance(x, Sample) or x.bbox == [None] * len(x.bbox):
                criteron = cri
            else:
                criteron = counterfact_cri
                bboxes = x.bbox
                x = x.img

                # is_bbox_exists = x.new_ones(x.shape[0], dtype=torch.bool)
                mask = x.new_ones(x.shape[0], 1, *x.shape[2:])
                for i, bbox in enumerate(bboxes):
                    for coord_x, coord_y, w, h in zip(bbox.xs, bbox.ys, bbox.ws, bbox.hs):
                        mask[i, 0, coord_y:(coord_y + h), coord_x:(coord_x + w)] = 0.

                impute_x = inpaint_model(x, mask)
                impute_y = (-y - 1)

                x = torch.cat([x, impute_x], dim=0)
                # label -1 as negative of class 0, -2 as negative of class 1 etc...
                y = torch.cat([y, impute_y], dim=0)

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = bit_hyperrule.get_lr(step, n_train, args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            with chrono.measure("fprop"):
                logits = model(x)
                if mixup > 0.0:
                    c = mixup_criterion(criteron, logits, y_a, y_b, mixup_l)
                else:
                    c = criteron(logits, y)
                c_num = float(c.data.cpu().numpy())  # Also ensures a sync point.

            # Accumulate grads
            with chrono.measure("grads"):
                loss = (c / args.batch_split)
                if args.fp16:
                    with amp.scale_loss(loss, optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                accum_steps += 1

            accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
            logger.info(f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
            logger.flush()

            # Update params
            if accum_steps == args.batch_split:
                with chrono.measure("update"):
                    optim.step()
                    optim.zero_grad()
                step += 1
                accum_steps = 0
                # Sample new mixup ratio for next batch
                mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

                # Run evaluation and save the model.
                if args.eval_every and step % args.eval_every == 0:
                    run_eval(model, valid_loader, device, chrono, logger, step)
                    if args.save:
                        torch.save({
                            "step": step,
                            "model": model.module.state_dict(),
                            "optim": optim.state_dict(),
                        }, savename)

            end = time.time()

    # Save model!!
    if args.save:
        torch.save({
            "step": step,
            "model": model.module.state_dict(),
            "optim": optim.state_dict(),
        }, savename)
        json.dump({
            'model': args.model,
            'head_size': n_classes,
            'inpaint': args.inpaint,
            'dataset': args.dataset,
        }, open(pjoin(args.logdir, args.name, 'hyperparams.json'), 'w'))

    # Final eval at end of training.
    run_eval(model, valid_loader, device, chrono, logger, step='end')

    logger.info(f"Timings:\n{chrono}")
def main():
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    train_set, valid_set, train_loader, valid_loader = mktrainval()

    model = models.KNOWN_MODELS["BiT-M-R50x1"](head_size=len(valid_set.classes),
                                               zero_head=True)
    model.load_from(np.load("BiT-M-R50x1.npz"))
    model = torch.nn.DataParallel(model)

    step = 0
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    savename = pjoin("./log", "cifar10", "bit.pth.tar")
    try:
        checkpoint = torch.load(savename, map_location="cpu")
        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
    except FileNotFoundError:
        print("model not found")

    model = model.to(device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

    all_top1 = []
    all_loss = []
    all_val_loss = []

    for x, y in recycle(train_loader):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        lr = bit_hyperrule.get_lr(step, len(train_set), 0.003)
        if lr is None:
            break
        for param_group in optim.param_groups:
            param_group["lr"] = lr

        if mixup > 0.0:
            x, y_a, y_b = mixup_data(x, y, mixup_l)

        logits = model(x)
        if mixup > 0.0:
            c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
        else:
            c = cri(logits, y)
        c_num = float(c.data.cpu().numpy())

        c.backward()

        top1, _ = topk(logits, y, ks=(1, 5))
        all_top1.extend(top1.cpu())
        print(f"[step {step}]: loss={c_num:.5f}, accu={np.mean(all_top1):.2%} (lr={lr:.1e})")
        all_loss.append(c_num)
        all_top1 = []

        optim.step()
        optim.zero_grad()
        step += 1

        mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

        val_loss = run_eval(model, valid_loader, device, step)
        all_val_loss.append(val_loss)

        if step % 10 == 0:
            torch.save({
                "step": step,
                "model": model.state_dict(),
                "optim": optim.state_dict(),
            }, savename)
            print("model saved to " + savename)

    plt.figure(figsize=(8, 8))
    plt.plot(range(1, len(all_loss) + 1), all_loss, label='Training loss')
    plt.plot(range(1, len(all_val_loss) + 1), all_val_loss, label='Validation loss')
    plt.legend(loc='lower right')
    plt.title('Loss per step')
    plt.show()
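# The loop above reports running accuracy via topk(logits, y, ks=(1, 5)), which
# is not defined in this file. A plausible sketch (an assumption, not the
# original helper): for each k it returns one boolean per example, telling
# whether the target class is among the top-k predictions.
import torch


def topk(output, target, ks=(1,)):
    _, pred = output.topk(max(ks), dim=1, largest=True, sorted=True)
    pred = pred.t()                                    # [max_k, batch]
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].any(dim=0) for k in ks]        # one bool vector per k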