def main(lr: float = 2e-3, epochs: int = 10, batch_size: int = 128):
    train_dl, valid_dl, n_features, scaler = get_dataloader(
        batch_size=batch_size, test_size=0.05)
    joblib.dump(scaler, CACHE_DIR / "scaler.jbl")
    print(f"# of Features: {n_features}")
    model = MlpClassifierModel(n_features).cuda()
    # model = MoeClassifierModel(n_features).cuda()
    # optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=3e-2)
    optimizer = RAdam(model.parameters(), lr=lr)
    total_steps = len(train_dl) * epochs
    checkpoints = CheckpointCallback(
        keep_n_checkpoints=1,
        checkpoint_dir=CACHE_DIR / "model_cache/",
        monitor_metric="accuracy")
    # Two LR stages: linear warm-up for ~10% of the steps, cosine annealing for the rest.
    lr_durations = [int(total_steps * 0.1), int(np.ceil(total_steps * 0.9))]
    break_points = [0] + list(np.cumsum(lr_durations))[:-1]
    callbacks = [
        MovingAverageStatsTrackerCallback(
            avg_window=len(train_dl),
            log_interval=len(train_dl) // 2),
        LearningRateSchedulerCallback(
            MultiStageScheduler([
                LinearLR(optimizer, 0.01, lr_durations[0]),
                CosineAnnealingLR(optimizer, lr_durations[1], 1e-6)
            ], start_at_epochs=break_points)),
        checkpoints,
    ]
    class_weights = torch.ones(14).cuda()
    class_weights[12] = 1.25  # up-weight class 12 ("Shot")
    bot = BaseBot(
        log_dir=CACHE_DIR / "logs",
        model=model,
        train_loader=train_dl,
        valid_loader=valid_dl,
        clip_grad=10.,
        optimizer=optimizer,
        echo=True,
        criterion=nn.CrossEntropyLoss(weight=class_weights),
        callbacks=callbacks,
        metrics=(Top1Accuracy(),),
        pbar=False,
        use_tensorboard=False,
        use_amp=False)
    bot.train(total_steps=total_steps, checkpoint_interval=len(train_dl))
    # Restore the best checkpoint, persist the final weights, then clean up.
    bot.load_model(checkpoints.best_performers[0][1])
    torch.save(bot.model.state_dict(), CACHE_DIR / "final_weights.pth")
    print("Model saved")
    checkpoints.remove_checkpoints(keep=0)
    print("Checkpoint removed")
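# Hypothetical companion helper (not part of the original script): a minimal
# sketch of how the artifacts written by main() -- scaler.jbl and
# final_weights.pth under CACHE_DIR -- could be loaded back for inference.
# It assumes joblib, torch, CACHE_DIR, and MlpClassifierModel are available as
# in main() above, and that `n_features` matches the value used at training time.
def load_for_inference(n_features: int):
    scaler = joblib.load(CACHE_DIR / "scaler.jbl")
    model = MlpClassifierModel(n_features).cuda()
    model.load_state_dict(torch.load(CACHE_DIR / "final_weights.pth"))
    model.eval()  # disable dropout / batch-norm updates
    return model, scaler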
def finetune(args, model, train_loader, valid_loader, criterion):
    total_steps = len(train_loader) * args.epochs
    optimizer = get_optimizer(model, args.lr)
    if args.debug:
        print("No decay:", [
            n for n, p in model.named_parameters()
            if any(nd in n for nd in NO_DECAY)
        ])
    checkpoints = CheckpointCallback(
        keep_n_checkpoints=1,
        checkpoint_dir=CACHE_DIR / "model_cache/",
        monitor_metric="loss")
    lr_durations = [int(total_steps * 0.2), int(np.ceil(total_steps * 0.8))]
    break_points = [0] + list(np.cumsum(lr_durations))[:-1]
    callbacks = [
        MovingAverageStatsTrackerCallback(
            avg_window=len(train_loader) // 20,
            log_interval=len(train_loader) // 16),
        LearningRateSchedulerCallback(
            MultiStageScheduler([
                LinearLR(optimizer, 0.01, lr_durations[0]),
                CosineAnnealingLR(optimizer, lr_durations[1])
            ], start_at_epochs=break_points)),
        checkpoints
    ]
    if model.linear_transform:
        callbacks.append(ScalerDebugCallback())
    bot = CosineSimilarityBot(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        clip_grad=10.,
        optimizer=optimizer,
        echo=True,
        criterion=criterion,
        callbacks=callbacks,
        pbar=True,
        use_tensorboard=False,
        use_amp=False)
    bot.logger.info("train batch size: %d", train_loader.batch_size)
    bot.train(total_steps=total_steps,
              checkpoint_interval=len(train_loader) // 4)
    bot.load_model(checkpoints.best_performers[0][1])
    checkpoints.remove_checkpoints(keep=0)
    return bot.model
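# Standalone worked example (illustration only) of the stage-boundary
# arithmetic shared by these training functions, assuming numpy is imported as
# np as above: with 1000 total steps and the 20%/80% split used in finetune(),
# warm-up covers steps [0, 200) and cosine annealing covers steps [200, 1000).
example_total_steps = 1000
example_durations = [int(example_total_steps * 0.2),
                     int(np.ceil(example_total_steps * 0.8))]          # [200, 800]
example_break_points = [0] + list(np.cumsum(example_durations))[:-1]   # [0, 200]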
def train(self,
          pattern: str = "cache/train/train_*.jl",
          max_q_len: int = 128,
          max_ex_len: int = 350,
          batch_size: int = 4,
          n_steps: int = 20000,
          lr: float = 3e-4,
          grad_accu: int = 1,
          sample_negatives: float = 1.0,
          log_freq: int = 200,
          checkpoint_interval: int = 3000,
          use_amp: bool = False):
    (train_ds, train_loader, _, valid_loader, model,
     optimizer) = self._setup(pattern, max_q_len, max_ex_len, batch_size, lr,
                              sample_negatives, use_amp)
    checkpoints = CheckpointCallback(
        keep_n_checkpoints=1,
        checkpoint_dir=CACHE_DIR / "model_cache/",
        monitor_metric="loss")
    lr_durations = [int(n_steps * 0.1), int(np.ceil(n_steps * 0.9))]
    break_points = [0] + list(np.cumsum(lr_durations))[:-1]
    callbacks = [
        MovingAverageStatsTrackerCallback(
            avg_window=log_freq * 2,
            log_interval=log_freq),
        LearningRateSchedulerCallback(
            MultiStageScheduler([
                LinearLR(optimizer, 0.01, lr_durations[0]),
                # LinearLR(optimizer, 0.001,
                #          lr_durations[1], upward=False)
                CosineAnnealingLR(optimizer, lr_durations[1])
            ], start_at_epochs=break_points)),
        checkpoints
        # TelegramCallback(
        #     token="559760930:AAGOgPA0OlqlFB7DrX0lyRc4Di3xeixdNO8",
        #     chat_id="213781869", name="QABot",
        #     report_evals=True
        # )
    ]
    bot = BasicQABot(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        log_dir=CACHE_DIR / "logs/",
        clip_grad=10.,
        optimizer=optimizer,
        echo=True,
        criterion=BasicQALoss(0.5, log_freq=log_freq * 4, alpha=0.005),
        callbacks=callbacks,
        pbar=True,
        use_tensorboard=False,
        gradient_accumulation_steps=grad_accu,
        metrics=(),
        use_amp=use_amp)
    bot.logger.info("train batch size: %d", train_loader.batch_size)
    bot.train(total_steps=n_steps, checkpoint_interval=checkpoint_interval)
    bot.load_model(checkpoints.best_performers[0][1])
    checkpoints.remove_checkpoints(keep=0)
    bot.model.save(CACHE_DIR / "final/")
# In[17]:

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

# In[18]:

if APEX_AVAILABLE:
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# In[19]:

total_steps = len(train_loader) * 3
checkpoints = CheckpointCallback(
    keep_n_checkpoints=1,
    checkpoint_dir=CACHE_DIR / "model_cache/",
    monitor_metric="accuracy")
lr_durations = [int(total_steps * 0.2), int(np.ceil(total_steps * 0.8))]
break_points = [0] + list(np.cumsum(lr_durations))[:-1]
callbacks = [
    MovingAverageStatsTrackerCallback(
        avg_window=len(train_loader) // 8,
        log_interval=len(train_loader) // 10),
    LearningRateSchedulerCallback(
        MultiStageScheduler([
            LinearLR(optimizer, 0.01, lr_durations[0]),
            CosineAnnealingLR(optimizer, lr_durations[1])
        ], start_at_epochs=break_points)),
    checkpoints
]

bot = SST2Bot(model=model,
def train_from_scratch(args, model, train_loader, valid_loader, criterion):
    total_steps = len(train_loader) * args.epochs
    optimizer = get_optimizer(model, args)
    if args.debug:
        print("No decay:", [
            n for n, p in model.named_parameters()
            if any(nd in n for nd in NO_DECAY)
        ])
    if args.amp:
        if not APEX_AVAILABLE:
            raise ValueError("Apex is not installed!")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp)
    checkpoints = CheckpointCallback(
        keep_n_checkpoints=1,
        checkpoint_dir=CACHE_DIR / "model_cache/",
        monitor_metric="accuracy")
    lr_durations = [int(total_steps * 0.25), int(np.ceil(total_steps * 0.75))]
    break_points = [0] + list(np.cumsum(lr_durations))[:-1]
    callbacks = [
        MovingAverageStatsTrackerCallback(
            avg_window=len(train_loader) // 5,
            log_interval=len(train_loader) // 6),
        LearningRateSchedulerCallback(
            MultiStageScheduler([
                LinearLR(optimizer, 0.01, lr_durations[0]),
                CosineAnnealingLR(optimizer, lr_durations[1])
            ], start_at_epochs=break_points)),
        checkpoints,
        LookaheadCallback(),  # this will be a dummy op when Lookahead is not used
        EarlyStoppingCallback(patience=8, min_improv=1e-2,
                              monitor_metric="accuracy"),
        WandbCallback(
            config={
                "epochs": args.epochs,
                "arch": args.arch,
                "mixup_alpha": args.mixup_alpha,
                "cutmix_alpha": args.cutmix_alpha,
                "batch_size": args.batch_size
            },
            name="Imagenette",
            watch_freq=200,
            watch_level="gradients",
            run_name=args.run_name if args.run_name else None)
    ]
    if args.mixup_alpha and args.cutmix_alpha:
        callbacks.append(
            RandomCallbackChoices([
                MixUpCallback(alpha=args.mixup_alpha, softmax_target=True),
                CutMixCallback(alpha=args.cutmix_alpha,
                               minmax=(0.2, 0.8),
                               softmax_target=True)
            ], p=[0.5, 0.5]))
    else:
        if args.mixup_alpha:
            callbacks.append(
                MixUpCallback(alpha=args.mixup_alpha, softmax_target=True))
        if args.cutmix_alpha:
            callbacks.append(
                CutMixCallback(alpha=args.cutmix_alpha,
                               minmax=None,
                               softmax_target=True))
    if BOT_TOKEN:
        callbacks.append(
            TelegramCallback(token=BOT_TOKEN,
                             chat_id=CHAT_ID,
                             name="Imagenette",
                             report_evals=True))
    bot = ImageClassificationBot(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        clip_grad=10.,
        optimizer=optimizer,
        echo=True,
        criterion=criterion,
        callbacks=callbacks,
        pbar=False,
        use_tensorboard=True,
        use_amp=(args.amp != ''))
    bot.train(total_steps=total_steps,
              checkpoint_interval=len(train_loader) // 2)
    bot.load_model(checkpoints.best_performers[0][1])
    torch.save(bot.model.state_dict(), CACHE_DIR / "final_weights.pth")
    checkpoints.remove_checkpoints(keep=0)
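# Rough stock-PyTorch analogue (an assumption, not the MultiStageScheduler /
# LearningRateSchedulerCallback API used above) of the same warm-up-then-cosine
# schedule, stepped once per batch. `optimizer` and `total_steps` are assumed
# to be defined as in train_from_scratch(); the 25%/75% split mirrors it.
from torch.optim.lr_scheduler import (CosineAnnealingLR as TorchCosine,
                                      LinearLR as TorchLinear, SequentialLR)

warmup_steps = int(total_steps * 0.25)
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        TorchLinear(optimizer, start_factor=0.01, total_iters=warmup_steps),
        TorchCosine(optimizer, T_max=total_steps - warmup_steps),
    ],
    milestones=[warmup_steps],
)
# Inside the training loop: call scheduler.step() after each optimizer.step().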