def _load_checkpoint(*, filename, state: "State"):
    """Restore run state and component states from a checkpoint file.

    The checkpoint is read with ``utils.load_checkpoint``. Unless the
    current stage name starts with ``"infer"``, the stage name and both
    epoch counters stored in the checkpoint are copied onto ``state``.
    The model, criterion, optimizer and scheduler held by ``state`` are
    then passed to ``utils.unpack_checkpoint`` in every case.

    Args:
        filename: path of the checkpoint file to load.
        state: run state object; its ``stage_name``, ``epoch``,
            ``global_epoch``, ``model``, ``criterion``, ``optimizer`` and
            ``scheduler`` attributes are read and/or overwritten.

    Raises:
        FileNotFoundError: if ``filename`` is not an existing file. This is
            a subclass of ``Exception``, so callers that caught the
            previous generic ``Exception`` keep working.
    """
    # Guard clause: fail early with a specific error type.
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    # Inference stages keep their own stage name and epoch counters.
    if not state.stage_name.startswith("infer"):
        state.stage_name = checkpoint["stage_name"]
        state.epoch = checkpoint["epoch"]
        state.global_epoch = checkpoint["global_epoch"]
        # @TODO: should we also load,
        # checkpoint_data, main_metric, minimize_metric, valid_loader ?
        # epoch_metrics, valid_metrics ?

    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )

    print(
        f"loaded checkpoint {filename} "
        f"(global epoch {checkpoint['global_epoch']}, "
        f"epoch {checkpoint['epoch']}, "
        f"stage {checkpoint['stage_name']})"
    )
def load_checkpoint(*, filename, state: "_State"):
    """Restore run state and component states from a checkpoint file.

    The checkpoint is read with ``utils.load_checkpoint``. Unless the
    current stage starts with ``"infer"``, the epoch, stage epoch and stage
    stored in the checkpoint are copied onto ``state``. The model,
    criterion, optimizer and scheduler held by ``state`` are then passed to
    ``utils.unpack_checkpoint`` in every case.

    Args:
        filename: path of the checkpoint file to load.
        state: run state object; its ``stage``, ``epoch``, ``stage_epoch``,
            ``model``, ``criterion``, ``optimizer`` and ``scheduler``
            attributes are read and/or overwritten.

    Raises:
        FileNotFoundError: if ``filename`` is not an existing file. This is
            a subclass of ``Exception``, so callers that caught the
            previous generic ``Exception`` keep working.
    """
    # Guard clause: fail early with a specific error type.
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    # Inference stages keep their own stage and epoch counters.
    if not state.stage.startswith("infer"):
        state.epoch = checkpoint["epoch"]
        state.stage_epoch = checkpoint["stage_epoch"]
        state.stage = checkpoint["stage"]

    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )

    print(
        f"loaded checkpoint {filename} "
        f"(epoch {checkpoint['epoch']}, "
        f"stage_epoch {checkpoint['stage_epoch']}, "
        f"stage {checkpoint['stage']})"
    )
def trace_model_from_checkpoint(logdir, method_name):
    """Trace the model stored in ``<logdir>/checkpoints/best.pth``.

    The experiment configuration is read from
    ``<logdir>/configs/_config.json``; the experiment and runner classes
    are imported from the copy of the experiment directory kept under
    ``<logdir>/code``. The model of the first experiment stage is built,
    filled from the checkpoint and handed to ``trace_model``.

    Args:
        logdir: log directory, as ``str`` or ``pathlib.Path``.
        method_name: forwarded to ``trace_model`` as its fourth positional
            argument.

    Returns:
        Whatever ``trace_model`` returns.
    """
    # Normalise once so that plain string paths work with the ``/`` operator
    # (previously only the ``code`` sub-path was wrapped in ``Path``).
    logdir = Path(logdir)
    config_path = logdir / "configs/_config.json"
    checkpoint_path = logdir / "checkpoints/best.pth"

    print("Load config")
    config: Dict[str, dict] = safitty.load(config_path)

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # We will use copy of expdir from logs for reproducibility
    expdir_from_logs = logdir / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = import_experiment_and_runner(expdir_from_logs)
    experiment: Experiment = ExperimentType(config)

    print("Load model state from checkpoints/best.pth")
    # The model is built for the first stage the experiment reports.
    model = experiment.get_model(next(iter(experiment.stages)))
    checkpoint = utils.load_checkpoint(checkpoint_path)
    utils.unpack_checkpoint(checkpoint, model=model)

    print("Tracing")
    traced = trace_model(model, experiment, RunnerType, method_name)

    print("Done")
    return traced
def run_whole_training(experiment_name: str, exp_config: ExperimenetConfig, runs_dir="runs"):
    """Build the model, optionally initialise it, and train every stage.

    Steps, in order:
      1. create the model on CUDA from ``exp_config.model_name`` and
         ``exp_config.dropout``;
      2. if ``exp_config.transfer_from_checkpoint`` is set, transfer the
         ``"model_state_dict"`` weights of that checkpoint into the model;
      3. if ``exp_config.resume_from_checkpoint`` is set, unpack that
         checkpoint into the model and report it;
      4. create ``<runs_dir>/<experiment_name>`` (must not exist yet) and
         write ``config.json`` into it;
      5. call ``run_stage_training`` for each entry of ``exp_config.stages``.

    Args:
        experiment_name: name of the sub-directory created under ``runs_dir``.
        exp_config: experiment configuration object.
        runs_dir: parent directory for experiment outputs.

    Raises:
        FileExistsError: if the experiment directory already exists.
    """
    model = get_model(exp_config.model_name, dropout=exp_config.dropout).cuda()

    # Optional weight transfer from another checkpoint.
    if exp_config.transfer_from_checkpoint:
        transfer_checkpoint = fs.auto_file(exp_config.transfer_from_checkpoint)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        source_state = load_checkpoint(transfer_checkpoint)
        transfer_weights(model, source_state["model_state_dict"])

    # Optional resume of the model weights.
    if exp_config.resume_from_checkpoint:
        resumed = load_checkpoint(fs.auto_file(exp_config.resume_from_checkpoint))
        unpack_checkpoint(resumed, model=model)
        print("Loaded model weights from:", exp_config.resume_from_checkpoint)
        report_checkpoint(resumed)

    experiment_dir = os.path.join(runs_dir, experiment_name)
    os.makedirs(experiment_dir, exist_ok=False)

    with open(os.path.join(experiment_dir, "config.json"), "w") as config_file:
        # NOTE(review): the jsonpickle string is itself JSON-encoded again, so
        # the file holds one quoted JSON string - confirm this is intended.
        encoded_config = jsonpickle.encode(exp_config)
        config_file.write(json.dumps(encoded_config, indent=2))

    for stage in exp_config.stages:
        run_stage_training(model, stage, exp_config, experiment_dir=experiment_dir)
def trace_model_from_checkpoint(logdir, logger, method_name='forward'):
    """Trace the model stored in ``<logdir>/checkpoints/best.pth``.

    Args:
        logdir: log directory; it is interpolated into path strings.
        logger: object with an ``info`` method used for progress messages.
        method_name: forwarded to ``trace_model``.

    Returns:
        Whatever ``trace_model`` returns.
    """
    config_file = f'{logdir}/configs/_config.json'
    weights_file = f'{logdir}/checkpoints/best.pth'

    logger.info('Load config')
    config = safitty.load(config_file)
    # The distributed section is dropped before the experiment is built.
    if 'distributed_params' in config:
        config.pop('distributed_params')

    # Get expdir name
    # noinspection SpellCheckingInspection,PyTypeChecker
    # We will use copy of expdir from logs for reproducibility
    expdir_basename = os.path.basename(config['args']['expdir'])
    # The experiment directory is looked up next to ``logdir`` (its parent).
    expdir_from_logs = os.path.abspath(join(logdir, '../', expdir_basename))

    logger.info('Import experiment and runner from logdir')
    ExperimentType, RunnerType = import_experiment_and_runner(
        Path(expdir_from_logs)
    )
    experiment = ExperimentType(config)

    logger.info('Load model state from checkpoints/best.pth')
    first_stage = next(iter(experiment.stages))
    model = experiment.get_model(first_stage)
    utils.unpack_checkpoint(utils.load_checkpoint(weights_file), model=model)

    logger.info('Tracing')
    traced = trace_model(model, experiment, RunnerType, method_name)

    logger.info('Done')
    return traced
def train(self, dataset, remaining_time_budget=None):
    """Run one ``_train_epoch`` pass and keep the result as a checkpoint.

    A fresh ``OneHeadNet``, Adam optimizer and multi-step LR scheduler are
    created on every call and, when ``self.checkpoint`` is set, restored
    from it. After the pass the packed states are stored back on
    ``self.checkpoint``. Does nothing once ``self.done_training`` is true.

    Args:
        dataset: passed to ``self.get_loader``.
        remaining_time_budget: accepted but not used by this method.
    """
    if self.done_training:
        return

    self.curr_epoch += 1

    net = OneHeadNet(**self.config['model_params']).to(self.device)
    adam = optim.Adam(net.parameters(), lr=self.learning_rate)
    lr_schedule = optim.lr_scheduler.MultiStepLR(
        adam, milestones=[5, 10, 20, 30, 40], gamma=0.3
    )

    # Continue from the state saved by the previous call, if any.
    if self.checkpoint is not None:
        unpack_checkpoint(
            self.checkpoint, model=net, optimizer=adam, scheduler=lr_schedule
        )

    train_loader = self.get_loader(
        dataset,
        transform=self.config['transform']['train'],
        train=True,
        epoch_frac=0.5,
    )

    train_losses, elapsed_time = self._train_epoch(
        data=train_loader,
        model=net,
        optimizer=adam,
        criterion=F.binary_cross_entropy_with_logits,
        scheduler=lr_schedule,
    )

    self.checkpoint = pack_checkpoint(
        model=net, optimizer=adam, scheduler=lr_schedule
    )

    prefix = '{curr}/{total} * epoch | '.format(
        curr=self.curr_epoch, total=self.total_epochs
    )
    logger.info(
        format_log_message(
            prefix, train_loss=np.mean(train_losses), elapsed_time=elapsed_time
        )
    )

    # Mark training as finished once the configured epoch count is reached.
    if self.curr_epoch >= self.total_epochs:
        self.done_training = True
def trace_model_from_checkpoint(logdir, logger, method_name='forward', file='best'):
    """Trace the model stored in ``<logdir>/checkpoints/<file>.pth``.

    Side effect: the parent directory of ``logdir`` is inserted at the
    front of ``sys.path`` before the experiment code is imported.

    Args:
        logdir: log directory; it is interpolated into path strings.
        logger: object with an ``info`` method used for progress messages.
        method_name: forwarded to ``trace_model``.
        file: checkpoint file name without the ``.pth`` suffix.

    Returns:
        Whatever ``trace_model`` returns.
    """
    config_file = f'{logdir}/configs/_config.json'
    weights_file = f'{logdir}/checkpoints/{file}.pth'

    logger.info('Load config')
    config = safitty.load(config_file)
    # The distributed section is dropped before the experiment is built.
    if 'distributed_params' in config:
        config.pop('distributed_params')

    # Get expdir name
    # noinspection SpellCheckingInspection,PyTypeChecker
    # We will use copy of expdir from logs for reproducibility
    expdir_name = config['args']['expdir']
    logger.info(f'expdir_name from args: {expdir_name}')

    sys.path.insert(0, os.path.abspath(join(logdir, '../')))
    expdir_from_logs = os.path.abspath(join(logdir, '../', expdir_name))
    logger.info(f'expdir_from_logs: {expdir_from_logs}')

    logger.info('Import experiment and runner from logdir')
    ExperimentType, RunnerType = import_experiment_and_runner(
        Path(expdir_from_logs)
    )
    experiment = ExperimentType(config)

    logger.info(f'Load model state from checkpoints/{file}.pth')
    model = experiment.get_model(next(iter(experiment.stages)))
    utils.unpack_checkpoint(utils.load_checkpoint(weights_file), model=model)

    # Tracing happens on CPU, with the first stage and the loader at index 0.
    device = 'cpu'
    first_stage = list(experiment.stages)[0]

    runner = RunnerType()
    runner.model = model
    runner.device = device
    batch = experiment.get_native_batch(first_stage, 0)

    logger.info('Tracing')
    traced = trace_model(
        model,
        runner,
        batch,
        method_name=method_name,
        mode='eval',
        requires_grad=False,
        opt_level=None,
        device=device,
    )

    logger.info('Done')
    return traced
def quantize_model_from_checkpoint(
    logdir: Path,
    checkpoint_name: str,
    stage: str = None,
    qconfig_spec: Optional[Union[Set, Dict]] = None,
    dtype: Optional[torch.dtype] = torch.qint8,
    backend: str = None,
) -> Model:
    """
    Quantize model using created experiment and runner.

    The configuration is read from ``<logdir>/configs/_config.json`` and
    the experiment code from ``<logdir>/code/<expdir name>``. The model of
    ``stage`` is built, filled from the checkpoint, moved to CPU and passed
    to ``quantization.quantize_dynamic``.

    Args:
        logdir (Union[str, Path]): Path to Catalyst logdir with model
        checkpoint_name (str): Name of model checkpoint to use
            (file name without the ``.pth`` suffix)
        stage (str): experiment's stage name; the first stage of the
            experiment is used when ``None``
        qconfig_spec: torch.quantization.quantize_dynamic parameter,
            you can define layers to be quantize
        dtype: type of the model parameters, default ``torch.qint8``
        backend: defines backend for quantization; when given it is
            assigned to ``torch.backends.quantized.engine``

    Returns:
        Quantized model
    """
    if backend is not None:
        torch.backends.quantized.engine = backend

    # The docstring allows ``str``; normalise so the ``/`` operator works.
    logdir = Path(logdir)
    config_path = logdir / "configs" / "_config.json"
    checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"

    # Use the same logger as every other message in this function
    # (this one used to go to the root ``logging`` module).
    logger.info("Load config")
    config: Dict[str, dict] = load_config(config_path)

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # We will use copy of expdir from logs for reproducibility
    expdir = logdir / "code" / config_expdir.name

    logger.info("Import experiment and runner from logdir")
    experiment: ConfigExperiment
    experiment, _, _ = prepare_config_api_components(expdir=expdir, config=config)

    logger.info(f"Load model state from checkpoints/{checkpoint_name}.pth")
    if stage is None:
        stage = list(experiment.stages)[0]

    model = experiment.get_model(stage)
    checkpoint = load_checkpoint(checkpoint_path)
    unpack_checkpoint(checkpoint, model=model)

    logger.info("Quantization is running...")
    quantized_model = quantization.quantize_dynamic(
        model.cpu(), qconfig_spec=qconfig_spec, dtype=dtype,
    )

    logger.info("Done")
    return quantized_model
def best(self) -> nn.Module:
    """Return ``self.model`` loaded from the ``best_full.pth`` checkpoint.

    The checkpoint is read from
    ``<project_dir>/<default_logdir>/checkpoints/best_full.pth``. If
    ``tuned_params.pkl`` exists in the same log directory, its content is
    loaded with ``torch.load`` and handed to the task-flow decoder through
    ``load_tuned``.

    Returns:
        The model object held by ``self.model``.
    """
    model = self.model
    run_dir = self.project_dir / self.default_logdir

    best_state = load_checkpoint(run_dir / 'checkpoints' / 'best_full.pth')
    unpack_checkpoint(best_state, model)

    tuned_file = run_dir / 'tuned_params.pkl'
    if tuned_file.exists():
        tuned_params = torch.load(tuned_file)
        self.task_flow.get_decoder().load_tuned(tuned_params)

    return model
def test(self, dataset, remaining_time_budget=None):
    """Predict sigmoid scores for every batch of ``dataset``.

    A fresh ``OneHeadNet`` is built, restored from ``self.checkpoint``,
    switched to eval mode and applied to the ``'features'`` entry of each
    batch with gradients disabled.

    Args:
        dataset: passed to ``self.get_loader``.
        remaining_time_budget: accepted but not used by this method.

    Returns:
        ``numpy`` array built from the per-batch outputs concatenated in
        loader order.
    """
    net = OneHeadNet(**self.config['model_params']).to(self.device)
    unpack_checkpoint(self.checkpoint, model=net)
    net.eval()

    collected = []
    batches = self.get_loader(
        dataset, transform=self.config['transform']['test'], train=False
    )

    with torch.no_grad():
        for batch in batches:
            features = self.to_device(batch['features'])
            scores = torch.sigmoid(net(features))
            # ``extend`` appends one entry per row of the batch output.
            collected.extend(scores.cpu().numpy())

    return np.array(collected)
def model_fn(model_dir):
    """Load the checkpoint in ``model_dir`` and return a ready-to-use model.

    The checkpoint named by the module-level ``checkpoint_fname`` is loaded
    and unpacked into a ``'seresnext50d_gap'`` model. The model is put in
    eval mode and optionally wrapped according to the module-level
    ``apply_softmax`` and ``tta`` settings. It is then moved to CUDA when
    available and wrapped in ``nn.DataParallel`` when more than one GPU is
    visible.

    Args:
        model_dir: directory that contains the checkpoint file.

    Returns:
        The (possibly wrapped) model.
    """
    model_path = path.join(model_dir, checkpoint_fname)  # '/opt/ml/model/model.pth'

    # already available in this method
    # torch.load(model_path, map_location=lambda storage, loc: storage)
    checkpoint = load_checkpoint(model_path)
    params = checkpoint['checkpoint_data']['cmd_args']

    # The architecture is pinned here; the checkpoint's ``cmd_args['model']``
    # entry is not consulted (the former ``is None`` fallback could never run).
    model_name = 'seresnext50d_gap'

    coarse_grading = params.get('coarse', False)

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    class_names = get_class_names(coarse_grading=coarse_grading)
    num_classes = len(class_names)
    model = get_model(model_name, pretrained=False, num_classes=num_classes)
    unpack_checkpoint(checkpoint, model=model)
    report_checkpoint(checkpoint)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model = model.eval()

    if apply_softmax:
        model = nn.Sequential(model, ApplySoftmaxToLogits())

    # Test-time augmentation wrappers, selected by the module-level ``tta``.
    if tta in ('flip', 'fliplr'):
        model = FlipLRMultiheadTTA(model)

    if tta == 'flip4':
        model = Flip4MultiheadTTA(model)

    if tta == 'fliplr_ms':
        model = MultiscaleFlipLRMultiheadTTA(model)

    with torch.no_grad():
        if torch.cuda.is_available():
            # Moves any wrapper modules added above as well.
            model = model.cuda()
            gpu_count = torch.cuda.device_count()
            if gpu_count > 1:
                model = nn.DataParallel(
                    model, device_ids=list(range(gpu_count))
                )

    return model
def load_checkpoint(*, filename, state: "_State"):
    """Restore the epoch and component states from a checkpoint file.

    The checkpoint is read with ``utils.load_checkpoint``, its ``"epoch"``
    entry is copied onto ``state`` and the model, criterion, optimizer and
    scheduler held by ``state`` are passed to ``utils.unpack_checkpoint``.

    Args:
        filename: path of the checkpoint file to load.
        state: run state object; its ``epoch``, ``model``, ``criterion``,
            ``optimizer`` and ``scheduler`` attributes are read and/or
            overwritten.

    Raises:
        FileNotFoundError: if ``filename`` is not an existing file. This is
            a subclass of ``Exception``, so callers that caught the
            previous generic ``Exception`` keep working.
    """
    # Guard clause: fail early with a specific error type.
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    state.epoch = checkpoint["epoch"]

    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )

    print(f"loaded checkpoint {filename} (epoch {checkpoint['epoch']})")
def trace_model_from_checkpoint(
    logdir: Path,
    method_name: str,
    checkpoint_name: str,
    mode: str = "eval",
    requires_grad: bool = False,
):
    """Trace the model stored in ``<logdir>/checkpoints/<checkpoint_name>.pth``.

    The configuration comes from ``<logdir>/configs/_config.json`` and the
    experiment code from ``<logdir>/code/<expdir name>``. The model of the
    first experiment stage is built, filled from the checkpoint and passed
    to ``trace_model``.

    Args:
        logdir: log directory.
        method_name: forwarded to ``trace_model``.
        checkpoint_name: checkpoint file name without the ``.pth`` suffix.
        mode: forwarded to ``trace_model``.
        requires_grad: forwarded to ``trace_model``.

    Returns:
        Whatever ``trace_model`` returns.
    """
    config_file = logdir / "configs" / "_config.json"
    weights_file = logdir / "checkpoints" / f"{checkpoint_name}.pth"

    print("Load config")
    config = safitty.load(config_file)

    # Get expdir name
    original_expdir = safitty.get(config, "args", "expdir", apply=Path)
    # We will use copy of expdir from logs for reproducibility
    expdir = Path(logdir) / "code" / original_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = import_experiment_and_runner(expdir)
    experiment = ExperimentType(config)

    print(f"Load model state from checkpoints/{checkpoint_name}.pth")
    first_stage = next(iter(experiment.stages))
    model = experiment.get_model(first_stage)
    utils.unpack_checkpoint(utils.load_checkpoint(weights_file), model=model)

    print("Tracing")
    traced = trace_model(
        model,
        experiment,
        RunnerType,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
    )

    print("Done")
    return traced
def _load(
    self,
    runner: "IRunner",
    resume_logpath: Any = None,
    resume_model: str = None,
    resume_runner: str = None,
):
    """Restore runner components from up to three checkpoint sources.

    Each source is handled independently and in this order; every handled
    source first calls ``runner.engine.wait_for_everyone()``.

    Args:
        runner: runner whose model (and, for ``resume_runner``, criterion,
            optimizer, scheduler and step counters) is restored.
        resume_logpath: checkpoint used to restore the model. When
            ``self.mode == "model"`` it is first tried as a plain state
            dict for the unwrapped model; if that fails it is unpacked
            with ``unpack_checkpoint`` instead.
        resume_model: checkpoint loaded directly as the state dict of the
            unwrapped model.
        resume_runner: checkpoint unpacked into model, criterion, optimizer
            and scheduler; its ``epoch_step``, ``batch_step`` and
            ``sample_step`` entries are copied onto the runner.
    """
    if resume_logpath is not None:
        runner.engine.wait_for_everyone()
        # Read the file once; both restore strategies use the same content.
        checkpoint = load_checkpoint(resume_logpath)
        if self.mode == "model":
            try:
                unwrapped_model = runner.engine.unwrap_model(runner.model)
                unwrapped_model.load_state_dict(checkpoint)
            # ``Exception`` rather than ``BaseException``: Ctrl-C and
            # interpreter shutdown must not be mistaken for a format
            # mismatch and trigger the fallback.
            except Exception:
                unpack_checkpoint(checkpoint=checkpoint, model=runner.model)
        else:
            unpack_checkpoint(checkpoint=checkpoint, model=runner.model)

    if resume_runner is not None:
        runner.engine.wait_for_everyone()
        checkpoint = load_checkpoint(resume_runner)
        unpack_checkpoint(
            checkpoint=checkpoint,
            model=runner.model,
            criterion=runner.criterion,
            optimizer=runner.optimizer,
            scheduler=runner.scheduler,
        )
        runner.epoch_step = checkpoint["epoch_step"]
        runner.batch_step = checkpoint["batch_step"]
        runner.sample_step = checkpoint["sample_step"]

    if resume_model is not None:
        runner.engine.wait_for_everyone()
        unwrapped_model = runner.engine.unwrap_model(runner.model)
        unwrapped_model.load_state_dict(load_checkpoint(resume_model))
def main():
    """Entry point of the segmentation test script.

    Parses and checks the arguments, builds the test dataset slice and
    loader, and - unless ``args.has_prediction`` is set - builds the model
    selected by ``args.arch``, loads ``args.model_path`` into it and runs
    ``test``. Accuracy is computed with ``cal_acc`` for every split except
    ``'test'``.

    Sets the module-level ``args`` and ``logger`` and the
    ``CUDA_VISIBLE_DEVICES`` environment variable.

    Raises:
        ValueError: if ``args.arch`` matches none of the known architectures.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Normalisation statistics, rescaled from [0, 1] to the 0-255 range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)

    # Restrict the run to [index_start, index_end); a step of 0 means
    # "everything from index_start onwards".
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    # Close the names file deterministically instead of leaking the handle.
    with open(args.names_path) as names_file:
        names = [line.rstrip('\n') for line in names_file]

    if not args.has_prediction:
        print('arch: ', args.arch)
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h,
                           mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax,
                           pretrained=False)
        elif 'of' in args.arch:
            from catalyst_cityscapes.model.mymodel import ofPSPNet
            print('ofPSPNet')
            model = ofPSPNet(encoder_name=str(args.layers),
                             classes=args.classes,
                             zoom_factor=args.zoom_factor,
                             pretrained=False)
        elif 'smp' in args.arch:
            from catalyst_cityscapes.model.mymodel import smpPSPNet
            print('smpPSPNet')
            model = smpPSPNet(encoder_name='resnet%d' % args.layers,
                              classes=args.classes)
        else:
            # Previously this fell through to an UnboundLocalError below.
            raise ValueError(
                "unsupported architecture '{}'".format(args.arch))

        # logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            if 'of' in args.arch or 'smp' in args.arch:
                # These architectures are restored through catalyst's helpers.
                from catalyst import utils
                checkpoint = utils.load_checkpoint(args.model_path)
                print('checkpoint keys', list(checkpoint))
                utils.unpack_checkpoint(checkpoint, model=model)
            else:
                checkpoint = torch.load(args.model_path)
                model.load_state_dict(checkpoint['state_dict'], strict=False)
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.model_path))

        test(test_loader,
             test_data.data_list,
             model,
             args.classes,
             mean,
             std,
             args.base_size,
             args.test_h,
             args.test_w,
             args.scales,
             gray_folder,
             color_folder,
             colors,
             is_med=args.get('is_med', False))

    if args.split != 'test':
        cal_acc(test_data.data_list,
                gray_folder,
                args.classes,
                names,
                is_med=args.get('is_med', False),
                label_mapping=args.get('label_mapping', None))
def trace_model_from_runner(
    runner: IRunner,
    checkpoint_name: str = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
) -> ScriptModule:
    """
    Traces model using created experiment and runner.

    The runner's model is temporarily moved to ``device`` (and, when
    ``checkpoint_name`` is given, temporarily loaded with that checkpoint).
    Its previous weights, train/eval mode, ``requires_grad`` flags and
    device are restored afterwards, even if tracing raises.

    Args:
        runner (Runner): Current runner.
        checkpoint_name (str): Name of model checkpoint to use, if None
            traces current model from runner
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        (ScriptModule): Traced model

    Raises:
        AssertionError: if an argument name of the traced method is missing
            from ``runner.input``.
    """
    logdir = runner.logdir
    model = get_nn_from_ddp_module(runner.model)

    if checkpoint_name is not None:
        # Keep the current weights so they can be put back after tracing.
        dumped_checkpoint = pack_checkpoint(model=model)
        checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
        checkpoint = load_checkpoint(filepath=checkpoint_path)
        unpack_checkpoint(checkpoint=checkpoint, model=model)

    # getting input names of args for method since we don't have Runner
    # and we don't know input_key to preprocess batch for method call
    fn = getattr(model, method_name)
    method_argnames = _get_input_argnames(fn=fn, exclude=["self"])

    batch = {}
    for name in method_argnames:
        # TODO: We don't know input_keys without runner
        assert name in runner.input, (
            "Input batch should contain the same keys as input argument "
            "names of `forward` function to be traced correctly")
        batch[name] = runner.input[name]

    batch = any2device(batch, device)

    # Dumping previous runner of the model, we will need it to restore
    _device, _is_training, _requires_grad = (
        runner.device,
        model.training,
        get_requires_grad(model),
    )

    # Function to run prediction on batch
    def predict_fn(model: Model, inputs, **kwargs):
        return model(**inputs, **kwargs)

    try:
        model.to(device)
        traced_model = trace_model(
            model=model,
            predict_fn=predict_fn,
            batch=batch,
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            device=device,
        )
    finally:
        # Always undo the temporary changes, so that a failed trace does not
        # leave the runner's model with foreign weights, mode or device.
        if checkpoint_name is not None:
            unpack_checkpoint(checkpoint=dumped_checkpoint, model=model)

        # Restore previous runner of the model
        getattr(model, "train" if _is_training else "eval")()
        set_requires_grad(model, _requires_grad)
        model.to(_device)

    return traced_model
def trace_model_from_checkpoint(
    logdir: Path,
    method_name: str,
    checkpoint_name: str,
    stage: str = None,
    loader: Union[str, int] = None,
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
):
    """
    Traces model using created experiment and runner.

    The configuration is read from ``<logdir>/configs/_config.json`` and
    the experiment code from ``<logdir>/code/<expdir name>``. A runner is
    created from the config's ``runner_params`` and its ``predict_batch``
    is used to run the model on one batch taken from the stage loaders.

    Args:
        logdir (Union[str, Path]): Path to Catalyst logdir with model
        checkpoint_name (str): Name of model checkpoint to use
            (file name without the ``.pth`` suffix)
        stage (str): experiment's stage name; the first stage is used
            when ``None``
        loader (Union[str, int]): experiment's loader name or its index;
            index ``0`` is used when ``None``
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        the traced model
    """
    # The docstring allows ``str``; normalise so the ``/`` operator works.
    logdir = Path(logdir)
    config_path = logdir / "configs" / "_config.json"
    checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
    print("Load config")
    config: Dict[str, dict] = load_config(config_path)
    # A missing or null ``runner_params`` entry means "no extra arguments".
    runner_params = config.get("runner_params", {}) or {}

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # We will use copy of expdir from logs for reproducibility
    expdir = logdir / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = import_experiment_and_runner(expdir)
    experiment: ConfigExperiment = ExperimentType(config)

    print(f"Load model state from checkpoints/{checkpoint_name}.pth")
    if stage is None:
        stage = list(experiment.stages)[0]

    model = experiment.get_model(stage)
    checkpoint = load_checkpoint(checkpoint_path)
    unpack_checkpoint(checkpoint, model=model)

    runner: RunnerType = RunnerType(**runner_params)
    runner.model, runner.device = model, device

    if loader is None:
        loader = 0
    batch = get_native_batch_from_loaders(
        loaders=experiment.get_loaders(stage), loader=loader)

    # function to run prediction on batch
    def predict_fn(model, inputs, **kwargs):
        # Temporarily swap the runner's model for the one being traced and
        # put the previous one back even if ``predict_batch`` raises.
        _model = runner.model
        runner.model = model
        try:
            return runner.predict_batch(inputs, **kwargs)
        finally:
            runner.model = _model

    print("Tracing")
    traced_model = trace_model(
        model=model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    print("Done")
    return traced_model
def main():
    """Command-line entry point: train a classification model per fold.

    For every requested fold (or once with ``fold=None``) the script:
      1. derives a checkpoint prefix / log directory under ``runs`` and
         dumps the parsed arguments there as JSON;
      2. builds the model on CUDA, optionally transfers weights
         (``--transfer``) and/or loads a checkpoint (``--checkpoint``);
      3. builds datasets, loaders, criterions and callbacks;
      4. runs up to three training stages with one ``SupervisedRunner``:
         warmup (encoder frozen, 0.1 x LR), main training, and fine-tuning
         (encoder frozen again). After the main stage the best checkpoint
         is cleaned, reloaded into the model and reported.

    Raises:
        AssertionError: if none of the ``--use-*`` dataset flags is given.
        FileExistsError: if the log directory already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--swa', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--use-idrid', action='store_true')
    parser.add_argument('--use-messidor', action='store_true')
    parser.add_argument('--use-aptos2015', action='store_true')
    parser.add_argument('--use-aptos2019', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--coarse', action='store_true')
    parser.add_argument('-acc', '--accumulation-steps', type=int, default=1,
                        help='Number of batches to process')
    parser.add_argument('-dd', '--data-dir', type=str, default='data',
                        help='Data directory')
    parser.add_argument('-m', '--model', type=str, default='resnet18_gap',
                        help='')
    parser.add_argument('-b', '--batch-size', type=int, default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e', '--epochs', type=int, default=100,
                        help='Epoch to run')
    parser.add_argument('-es', '--early-stopping', type=int, default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f', '--fold', action='append', type=int,
                        default=None)
    parser.add_argument('-ft', '--fine-tune', default=0, type=int)
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--criterion-reg', type=str, default=None, nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-ord', type=str, default=None, nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-cls', type=str, default=['ce'],
                        nargs='+', help='Criterion')
    parser.add_argument('-l1', type=float, default=0,
                        help='L1 regularization loss')
    parser.add_argument('-l2', type=float, default=0,
                        help='L2 regularization loss')
    parser.add_argument('-o', '--optimizer', default='Adam',
                        help='Name of the optimizer')
    parser.add_argument('-p', '--preprocessing', default=None,
                        help='Preprocessing method')
    parser.add_argument(
        '-c', '--checkpoint', type=str, default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w', '--workers',
                        default=multiprocessing.cpu_count(), type=int,
                        help='Num workers')
    parser.add_argument('-a', '--augmentations', default='medium', type=str,
                        help='')
    parser.add_argument('-tta', '--tta', default=None, type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-s', '--scheduler', default='multistep', type=str,
                        help='')
    parser.add_argument('--size', default=512, type=int,
                        help='Image size for training & inference')
    parser.add_argument('-wd', '--weight-decay', default=0, type=float,
                        help='L2 weight decay')
    parser.add_argument('-wds', '--weight-decay-step', default=None,
                        type=float,
                        help='L2 weight decay step to add after each epoch')
    parser.add_argument('-d', '--dropout', default=0.0, type=float,
                        help='Dropout before head layer')
    parser.add_argument(
        '--warmup', default=0, type=int,
        help='Number of warmup epochs with 0.1 of the initial LR and frozed encoder'
    )
    # Help text fixed: it used to repeat the ``--dropout`` description.
    parser.add_argument(
        '-x', '--experiment', default=None, type=str,
        help='Experiment name; replaces the generated checkpoint prefix')

    args = parser.parse_args()

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    l1 = args.l1
    l2 = args.l2
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    fine_tune = args.fine_tune
    criterion_reg_name = args.criterion_reg
    criterion_cls_name = args.criterion_cls
    criterion_ord_name = args.criterion_ord
    folds = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    use_swa = args.swa
    show_batches = args.show
    scheduler_name = args.scheduler
    verbose = args.verbose
    weight_decay = args.weight_decay
    use_idrid = args.use_idrid
    use_messidor = args.use_messidor
    use_aptos2015 = args.use_aptos2015
    use_aptos2019 = args.use_aptos2019
    warmup = args.warmup
    dropout = args.dropout
    experiment = args.experiment
    preprocessing = args.preprocessing
    weight_decay_step = args.weight_decay_step
    coarse_grading = args.coarse
    class_names = get_class_names(coarse_grading)

    # At least one dataset must be selected.
    assert use_aptos2015 or use_aptos2019 or use_idrid or use_messidor

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()

    # No explicit folds: run once with ``fold=None``.
    if folds is None or len(folds) == 0:
        folds = [None]

    for fold in folds:
        torch.cuda.empty_cache()

        # Build a descriptive prefix from the run settings ...
        checkpoint_prefix = f'{model_name}_{args.size}_{augmentations}'

        if preprocessing is not None:
            checkpoint_prefix += f'_{preprocessing}'
        if use_aptos2019:
            checkpoint_prefix += '_aptos2019'
        if use_aptos2015:
            checkpoint_prefix += '_aptos2015'
        if use_messidor:
            checkpoint_prefix += '_messidor'
        if use_idrid:
            checkpoint_prefix += '_idrid'
        if coarse_grading:
            checkpoint_prefix += '_coarse'

        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'

        checkpoint_prefix += f'_{random_name}'

        # ... unless an explicit experiment name overrides it.
        if experiment is not None:
            checkpoint_prefix = experiment

        directory_prefix = f'{current_time}/{checkpoint_prefix}'
        log_dir = os.path.join('runs', directory_prefix)
        os.makedirs(log_dir, exist_ok=False)

        config_fname = os.path.join(log_dir, f'{checkpoint_prefix}.json')
        with open(config_fname, 'w') as f:
            train_session_args = vars(args)
            f.write(json.dumps(train_session_args, indent=2))

        set_manual_seed(args.seed)
        num_classes = len(class_names)
        model = get_model(model_name, num_classes=num_classes,
                          dropout=dropout).cuda()

        if args.transfer:
            transfer_checkpoint = fs.auto_file(args.transfer)
            print("Transfering weights from model checkpoint",
                  transfer_checkpoint)
            checkpoint = load_checkpoint(transfer_checkpoint)
            pretrained_dict = checkpoint['model_state_dict']

            # Load tensors one at a time so that a single incompatible
            # entry is reported and skipped instead of aborting the transfer.
            for name, value in pretrained_dict.items():
                try:
                    model.load_state_dict(
                        collections.OrderedDict([(name, value)]),
                        strict=False)
                except Exception as e:
                    print(e)

            report_checkpoint(checkpoint)

        if args.checkpoint:
            checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_aptos2019=use_aptos2019,
            use_aptos2015=use_aptos2015,
            use_idrid=use_idrid,
            use_messidor=use_messidor,
            use_unsupervised=False,
            coarse_grading=coarse_grading,
            image_size=image_size,
            augmentation=augmentations,
            preprocessing=preprocessing,
            target_dtype=int,
            fold=fold,
            folds=4)

        train_loader, valid_loader = get_dataloaders(
            train_ds, valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            train_sizes=train_sizes,
            balance=balance,
            balance_datasets=balance_datasets,
            balance_unlabeled=False)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        print('Datasets :', data_dir)
        print(' Train size :', len(train_loader), len(train_loader.dataset))
        print(' Valid size :', len(valid_loader), len(valid_loader.dataset))
        print(' Aptos 2019 :', use_aptos2019)
        print(' Aptos 2015 :', use_aptos2015)
        print(' IDRID :', use_idrid)
        print(' Messidor :', use_messidor)
        print('Train session :', directory_prefix)
        print(' FP16 mode :', fp16)
        print(' Fast mode :', fast)
        print(' Mixup :', mixup)
        print(' Balance cls. :', balance)
        print(' Balance ds. :', balance_datasets)
        print(' Warmup epoch :', warmup)
        print(' Train epochs :', num_epochs)
        print(' Fine-tune ephs :', fine_tune)
        print(' Workers :', num_workers)
        print(' Fold :', fold)
        print(' Log dir :', log_dir)
        print(' Augmentations :', augmentations)
        print('Model :', model_name)
        print(' Parameters :', count_parameters(model))
        print(' Image size :', image_size)
        print(' Dropout :', dropout)
        print(' Classes :', class_names, num_classes)
        print('Optimizer :', optimizer_name)
        print(' Learning rate :', learning_rate)
        print(' Batch size :', batch_size)
        print(' Criterion (cls):', criterion_cls_name)
        print(' Criterion (reg):', criterion_reg_name)
        print(' Criterion (ord):', criterion_ord_name)
        print(' Scheduler :', scheduler_name)
        print(' Weight decay :', weight_decay, weight_decay_step)
        print(' L1 reg. :', l1)
        print(' L2 reg. :', l2)
        print(' Early stopping :', early_stopping)

        # model training
        callbacks = []
        criterions = {}

        main_metric = 'cls/kappa'
        if criterion_reg_name is not None:
            cb, crits = get_reg_callbacks(criterion_reg_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_ord_name is not None:
            cb, crits = get_ord_callbacks(criterion_ord_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_cls_name is not None:
            cb, crits = get_cls_callbacks(criterion_cls_name,
                                          num_classes=num_classes,
                                          num_epochs=num_epochs,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if l1 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l1, end_wd=l1,
                                         schedule=None, prefix='l1', p=1)
            ]

        if l2 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l2, end_wd=l2,
                                         schedule=None, prefix='l2', p=2)
            ]

        callbacks += [CustomOptimizerCallback()]

        runner = SupervisedRunner(input_key='image')

        # Pretrain/warmup
        if warmup:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer('Adam',
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate * 0.1)

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=None,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'warmup'),
                         num_epochs=warmup,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer

        # Main train
        if num_epochs:
            set_trainable(model.encoder, True, False)

            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate,
                                      weight_decay=weight_decay)

            if use_swa:
                from torchcontrib.optim import SWA
                optimizer = SWA(optimizer,
                                swa_start=len(train_loader),
                                swa_freq=512)

            scheduler = get_scheduler(scheduler_name, optimizer,
                                      lr=learning_rate,
                                      num_epochs=num_epochs,
                                      batches_in_epoch=len(train_loader))

            # Additional callbacks that specific to main stage only added
            # here to copy of callbacks
            main_stage_callbacks = callbacks
            if early_stopping:
                es_callback = EarlyStoppingCallback(early_stopping,
                                                    min_delta=1e-4,
                                                    metric=main_metric,
                                                    minimize=False)
                # ``+`` builds a new list, so ``callbacks`` itself is untouched.
                main_stage_callbacks = callbacks + [es_callback]

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=main_stage_callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'main'),
                         num_epochs=num_epochs,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer, scheduler

            best_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)

            # Restoring best model from checkpoint
            checkpoint = load_checkpoint(best_checkpoint)
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        # Stage 3 - Fine tuning
        if fine_tune:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate)
            scheduler = get_scheduler('multistep', optimizer,
                                      lr=learning_rate,
                                      num_epochs=fine_tune,
                                      batches_in_epoch=len(train_loader))

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'finetune'),
                         num_epochs=fine_tune,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            best_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'finetune',
                                            'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)
def main():
    """Train an INRIA segmentation model (optionally distributed) and predict a sample.

    Parses the command line, builds the model, datasets, loaders, callbacks,
    optimizer and scheduler, and runs ``SupervisedRunner.train`` when
    ``--epochs`` is positive. On the master process the best checkpoint is
    cleaned, loaded back into the model and used to write a mask for
    ``sample_color.jpg`` into the run's log directory.

    Raises:
        ValueError: if neither ``--data-dir`` nor ``INRIA_DATA_DIR`` provides
            a data directory.
    """
    parser = argparse.ArgumentParser()

    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)

    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument(
        "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="b6_unet32_s2", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    # One optional criterion list per deep-supervision mask stride.
    for stride in (2, 4, 8, 16):
        parser.add_argument(
            f"-l{stride}",
            f"--criterion{stride}",
            type=str,
            required=False,
            action="append",
            nargs="+",
            help=f"Criterion for stride {stride} mask",
        )
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()
    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    # WORLD_SIZE in the environment switches the script into distributed mode.
    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])

        print("Initializing init_process_group", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    # A non-distributed run is always its own master.
    is_master = args.is_master or not args.distributed

    # Value passed to runner.train(fp16=...): a settings dict, or False when
    # neither distributed training nor fp16 is requested.
    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}
        if fp16:
            distributed_params["amp"] = True
    elif fp16:
        distributed_params = {"amp": True}
    else:
        distributed_params = False

    # Offset the seed by the rank so each process gets a different stream.
    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    criterions2 = args.criterion2
    criterions4 = args.criterion4
    criterions8 = args.criterion8
    criterions16 = args.criterion16
    verbose = args.verbose
    show = args.show
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {"full_size_mask": False}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)
    if any([criterions2, criterions4, criterions8, criterions16]):
        custom_model_kwargs["need_supervision_masks"] = True
        print("Enabling supervision masks")

    model: nn.Module = get_model(model_name, num_classes=16, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]
        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)
        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "jaccard"

    # Run name: timestamp + model + feature flags, unless --experiment overrides it.
    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"
    if fp16:
        checkpoint_prefix += "_fp16"
    if fast:
        checkpoint_prefix += "_fast"
    if online_pseudolabeling:
        checkpoint_prefix += "_opl"
    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"
    if experiment is not None:
        checkpoint_prefix = experiment

    default_callbacks = [
        JaccardMetricPerImage(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            prefix="jaccard",
            inputs_to_labels=depth2mask,
            outputs_to_labels=decode_depth_mask,
        ),
    ]

    # Checkpointing and hyper-parameter logging happen on the master only.
    if is_master:
        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="jaccard", target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout),
                }
            ),
        ]

    if show:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
            inputs_to_labels=depth2mask,
            outputs_to_labels=decode_depth_mask,
            max_images=16,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False),
            ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True),
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
        make_mask_target_fn=mask_to_ce_target,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        # train_sampler may be None (the train loader below relies on that via
        # `shuffle=train_sampler is None`), so size the epoch from the dataset
        # in that case, exactly as the pseudolabeling branch does.
        if train_sampler is not None:
            num_samples = 2 * train_sampler.num_samples
        else:
            num_samples = 2 * len(train_ds)
        train_sampler = WeightedRandomSampler(weights, num_samples)
        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()

        if online_pseudolabeling:
            unlabeled_label = get_pseudolabeling_dataset(
                data_dir, include_masks=False, augmentation=None, image_size=image_size
            )
            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            if args.distributed:
                label_sampler = DistributedSampler(unlabeled_label, args.world_size, args.local_rank, shuffle=False)
            else:
                label_sampler = None

            loaders["infer"] = DataLoader(
                unlabeled_label,
                batch_size=batch_size // 2,
                num_workers=num_workers,
                pin_memory=True,
                sampler=label_sampler,
                drop_last=False,
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)

            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label))
            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="infer",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]
            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(
                    train_sampler, args.world_size, args.local_rank, shuffle=True
                )
            else:
                train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True)
            valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        loaders["valid"] = DataLoader(
            valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler
        )

        loss_callbacks, loss_criterions = get_criterions(
            criterions, criterions2, criterions4, criterions8, criterions16
        )
        callbacks += loss_callbacks

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # These schedulers are stepped once per batch.
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        log_dir = os.path.join("runs", checkpoint_prefix)

        if is_master:
            # exist_ok=False: refuse to reuse an existing run directory.
            os.makedirs(log_dir, exist_ok=False)
            config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
            with open(config_fname, "w") as f:
                train_session_args = vars(args)
                f.write(json.dumps(train_session_args, indent=2))

            print("Train session :", checkpoint_prefix)
            print(" FP16 mode :", fp16)
            print(" Fast mode :", args.fast)
            print(" Train mode :", train_mode)
            print(" Epochs :", num_epochs)
            print(" Workers :", num_workers)
            print(" Data dir :", data_dir)
            print(" Log dir :", log_dir)
            print(" Augmentations :", augmentations)
            print(" Train size :", "batches", len(loaders["train"]), "dataset", len(train_ds))
            print(" Valid size :", "batches", len(loaders["valid"]), "dataset", len(valid_ds))
            print("Model :", model_name)
            print(" Parameters :", count_parameters(model))
            print(" Image size :", image_size)
            print("Optimizer :", optimizer_name)
            print(" Learning rate :", learning_rate)
            print(" Batch size :", batch_size)
            print(" Criterion :", criterions)
            print(" Use weight mask:", need_weight_mask)
            if args.distributed:
                print("Distributed")
                print(" World size :", args.world_size)
                print(" Local rank :", args.local_rank)
                print(" Is master :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=loss_criterions,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        if is_master:
            best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
            model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
            clean_checkpoint(best_checkpoint, model_checkpoint)

            unpack_checkpoint(torch.load(model_checkpoint), model=model)

            mask = predict(
                model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size
            )
            # Binarise to a 0/255 image before saving.
            mask = ((mask > 0) * 255).astype(np.uint8)
            name = os.path.join(log_dir, "sample_color.jpg")
            cv2.imwrite(name, mask)
def trace_model_from_checkpoint(
    logdir: Path,
    method_name: str,
    checkpoint_name: str,
    stage: str = None,
    loader: Union[str, int] = None,
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
):
    """
    Traces model using created experiment and runner.

    Args:
        logdir (Union[str, Path]): Path to Catalyst logdir with model
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        checkpoint_name (str): Name of model checkpoint to use
        stage (str): experiment's stage name; the first stage of the
            experiment is used when ``None``
        loader (Union[str, int]): experiment's loader name or its index;
            ``0`` is used when ``None``
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        the traced model
    """
    # The documented contract accepts plain strings, but the ``/`` operator
    # below only works on Path objects, so normalise once up front.
    logdir = Path(logdir)

    config_path = logdir / "configs" / "_config.json"
    checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
    print("Load config")
    config: Dict[str, dict] = safitty.load(config_path)

    # Get expdir name
    config_expdir = safitty.get(config, "args", "expdir", apply=Path)

    # We will use copy of expdir from logs for reproducibility
    expdir = logdir / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = import_experiment_and_runner(expdir)
    experiment: Experiment = ExperimentType(config)

    print(f"Load model state from checkpoints/{checkpoint_name}.pth")
    if stage is None:
        stage = list(experiment.stages)[0]

    model = experiment.get_model(stage)
    checkpoint = utils.load_checkpoint(checkpoint_path)
    utils.unpack_checkpoint(checkpoint, model=model)

    runner: RunnerType = RunnerType()
    runner.model, runner.device = model, device

    if loader is None:
        loader = 0
    # A real batch from the experiment is passed to trace_model below.
    batch = experiment.get_native_batch(stage, loader)

    print("Tracing")
    traced = trace_model(
        model,
        runner,
        batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    print("Done")
    return traced
def main():
    """Train an INRIA segmentation model (optional warmup stage) and predict a sample.

    Parses the command line, creates the run directory under ``runs/``, builds
    the model, datasets and callbacks, optionally runs a warmup stage, then
    the main training stage when ``--epochs`` is positive. After the main
    stage the best checkpoint is cleaned, loaded back into the model and used
    to write a mask for ``sample_color.jpg`` into the run directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA sattelite dataset"
    )
    parser.add_argument(
        "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0.0, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    use_dsv = args.dsv
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    def build_loss_callbacks(ignore_index):
        """Create one loss and one CriterionCallback per ``--criterion`` entry.

        Returns the ``{loss_name: loss}`` dict and the list of callbacks, in
        command-line order.
        """
        losses_by_name = {}
        loss_callbacks = []
        for loss_name, loss_weight in criterions:
            criterion_callback = CriterionCallback(
                prefix="seg_loss/" + loss_name,
                # Only the weighted BCE loss also consumes the weight mask.
                input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
                output_key=OUTPUT_MASK_KEY,
                criterion_key=loss_name,
                multiplier=float(loss_weight),
            )
            losses_by_name[loss_name] = get_loss(loss_name, ignore_index=ignore_index)
            loss_callbacks.append(criterion_callback)
            print("Using loss", loss_name, loss_weight)
        return losses_by_name, loss_callbacks

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]
        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)
        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
    main_metric = "optimized_jaccard"
    cmd_args = vars(args)

    # Run name: timestamp + model + feature flags, unless --experiment overrides it.
    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"
    if fp16:
        checkpoint_prefix += "_fp16"
    if fast:
        checkpoint_prefix += "_fast"
    if online_pseudolabeling:
        checkpoint_prefix += "_opl"
    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to reuse an existing run directory.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY),
        JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"),
        OptimalThreshold(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="optimized_jaccard"),
    ]

    if show:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False)
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        # train_sampler may be None (the train loaders below rely on that via
        # `shuffle=train_sampler is None`), so size the epoch from the dataset
        # in that case, exactly as the pseudolabeling branch does.
        if train_sampler is not None:
            num_samples = 2 * train_sampler.num_samples
        else:
            num_samples = 2 * len(train_ds)
        train_sampler = WeightedRandomSampler(weights, num_samples)
        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    # Pretrain/warmup
    if warmup:
        callbacks = default_callbacks.copy()
        criterions_dict, loss_callbacks = build_loss_callbacks(ignore_index=None)
        callbacks += loss_callbacks
        losses = [cb.prefix for cb in loss_callbacks]

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        ]

        parameters = get_lr_decay_parameters(model.named_parameters(), learning_rate, {"encoder": 0.1})
        optimizer = get_optimizer("RAdam", parameters, learning_rate=learning_rate * 0.1)

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        loaders["valid"] = DataLoader(
            valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            shuffle=False,
            drop_last=False,
        )

        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=None,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        del optimizer, loaders

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(
                data_dir, include_masks=False, augmentation=None, image_size=image_size
            )
            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            loaders["label"] = DataLoader(
                unlabeled_label, batch_size=batch_size // 2, num_workers=num_workers, pin_memory=True
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)

            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label))
            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="label",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]
            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        loaders["valid"] = DataLoader(valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

        # Create losses
        criterions_dict, loss_callbacks = build_loss_callbacks(ignore_index=ignore_index)
        callbacks += loss_callbacks
        losses = [cb.prefix for cb in loss_callbacks]

        if use_dsv:
            print("Using DSV")
            # Keep the key in its own variable: rebinding `criterions` here would
            # make the session summary below print "dsv" instead of the losses.
            dsv_criterion_key = "dsv"
            dsv_loss_name = "soft_bce"
            criterions_dict[dsv_criterion_key] = AdaptiveMaskLoss2d(
                get_loss(dsv_loss_name, ignore_index=ignore_index)
            )

            for dsv_input in [OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY, OUTPUT_MASK_32_KEY]:
                criterion_callback = CriterionCallback(
                    prefix="seg_loss_dsv/" + dsv_input,
                    input_key=OUTPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key=dsv_criterion_key,
                    multiplier=1.0,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # These schedulers are stepped once per batch.
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session :", checkpoint_prefix)
        print("\tFP16 mode :", fp16)
        print("\tFast mode :", args.fast)
        print("\tTrain mode :", train_mode)
        print("\tEpochs :", num_epochs)
        print("\tWorkers :", num_workers)
        print("\tData dir :", data_dir)
        print("\tLog dir :", log_dir)
        print("\tAugmentations :", augmentations)
        print("\tTrain size :", len(loaders["train"]), len(train_ds))
        print("\tValid size :", len(loaders["valid"]), len(valid_ds))
        print("Model :", model_name)
        print("\tParameters :", count_parameters(model))
        print("\tImage size :", image_size)
        print("Optimizer :", optimizer_name)
        print("\tLearning rate :", learning_rate)
        print("\tBatch size :", batch_size)
        print("\tCriterion :", criterions)
        print("\tUse weight mask:", need_weight_mask)

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, "main", "checkpoints", f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        unpack_checkpoint(torch.load(model_checkpoint), model=model)

        mask = predict(
            model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size
        )
        # Binarise to a 0/255 image before saving.
        mask = ((mask > 0) * 255).astype(np.uint8)
        name = os.path.join(log_dir, "sample_color.jpg")
        cv2.imwrite(name, mask)

        del optimizer, loaders
def main():
    """Predict masks for the INRIA test images with a trained checkpoint.

    Loads the checkpoint given by ``--checkpoint``, prints its stored metrics,
    optionally wraps the model in test-time augmentation, writes a mask for
    ``sample_color.jpg`` into the run directory, then writes a mask for every
    image under ``<data-dir>/test/images`` and a compressed copy of each mask
    produced by ``gdal_translate``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="unet", help="")
    parser.add_argument("-dd", "--data-dir", type=str, default=None, required=True, help="Data dir")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        required=True,
        help="Checkpoint filename to use as initial model weights",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch size for inference")
    parser.add_argument("-tta", "--tta", default=None, type=str, help="Type of TTA to use [fliplr, d4]")
    args = parser.parse_args()

    data_dir = args.data_dir
    checkpoint_file = auto_file(args.checkpoint)
    # The run directory is two levels above the checkpoint file.
    run_dir = os.path.dirname(os.path.dirname(checkpoint_file))
    out_dir = os.path.join(run_dir, "submit")
    os.makedirs(out_dir, exist_ok=True)

    checkpoint = load_checkpoint(checkpoint_file)
    checkpoint_epoch = checkpoint["epoch"]
    print("Loaded model weights from", args.checkpoint)
    print("Epoch :", checkpoint_epoch)
    print(
        "Metrics (Train):",
        "IoU:",
        checkpoint["epoch_metrics"]["train"]["jaccard"],
        "Acc:",
        checkpoint["epoch_metrics"]["train"]["accuracy"],
    )
    print(
        "Metrics (Valid):",
        "IoU:",
        checkpoint["epoch_metrics"]["valid"]["jaccard"],
        "Acc:",
        checkpoint["epoch_metrics"]["valid"]["accuracy"],
    )

    model = get_model(args.model)
    unpack_checkpoint(checkpoint, model=model)
    # threshold = checkpoint["epoch_metrics"]["valid"].get("optimized_jaccard/threshold", 0.5)
    threshold = 0.5
    print("Using threshold", threshold)

    # Select the mask output and apply a sigmoid before thresholding.
    model = nn.Sequential(PickModelOutput(model, OUTPUT_MASK_KEY), nn.Sigmoid())

    # Any other --tta value (including None) leaves the model unwrapped.
    if args.tta == "fliplr":
        model = TTAWrapper(model, fliplr_image2mask)
    elif args.tta == "flipscale":
        model = TTAWrapper(model, fliplr_image2mask)
        model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128])
    elif args.tta == "d4":
        model = TTAWrapper(model, d4_image2mask)

    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model = model.eval()

    mask = predict(model, read_inria_image("sample_color.jpg"), image_size=(512, 512), batch_size=args.batch_size)
    mask = ((mask > threshold) * 255).astype(np.uint8)
    name = os.path.join(run_dir, "sample_color.jpg")
    cv2.imwrite(name, mask)

    test_predictions_dir = os.path.join(out_dir, "test_predictions")
    test_predictions_dir_compressed = os.path.join(out_dir, "test_predictions_compressed")

    if args.tta is not None:
        test_predictions_dir += f"_{args.tta}"
        test_predictions_dir_compressed += f"_{args.tta}"

    os.makedirs(test_predictions_dir, exist_ok=True)
    os.makedirs(test_predictions_dir_compressed, exist_ok=True)

    test_images = find_in_dir(os.path.join(data_dir, "test", "images"))
    for fname in tqdm(test_images, total=len(test_images)):
        image = read_inria_image(fname)
        mask = predict(model, image, image_size=(512, 512), batch_size=args.batch_size)
        mask = ((mask > threshold) * 255).astype(np.uint8)
        name = os.path.join(test_predictions_dir, os.path.basename(fname))
        cv2.imwrite(name, mask)

        name_compressed = os.path.join(test_predictions_dir_compressed, os.path.basename(fname))
        # Pass an argument list and no shell: paths with spaces or shell
        # metacharacters are handed to gdal_translate verbatim.
        command = [
            "gdal_translate",
            "--config",
            "GDAL_PAM_ENABLED",
            "NO",
            "-co",
            "COMPRESS=CCITTFAX4",
            "-co",
            "NBITS=1",
            name,
            name_compressed,
        ]
        subprocess.call(command)
def main(args, _=None):
    """Run the ``catalyst-data text2embeddings`` script.

    Builds a BERT model and tokenizer (from a HuggingFace name or from local
    config/vocab files), optionally loads weights from ``args.in_model``,
    runs every non-null text of ``args.in_csv`` through the model and stores
    the pooled features per layer in ``{out_prefix}.{layer}.npy`` memmaps.
    The filtered dataframe is written to ``{out_prefix}.df.csv``.
    """
    batch_size = args.batch_size
    num_workers = args.num_workers
    max_length = args.max_length
    pooling_groups = args.pooling.split(",")
    bert_level = args.bert_level

    if bert_level is not None:
        assert (args.output_hidden_states
                ), "You need hidden states output for level specification"

    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    use_huggingface = getattr(args, "in_huggingface", False)
    if use_huggingface:
        config_source = args.in_huggingface
    else:
        config_source = args.in_config
    model_config = BertConfig.from_pretrained(config_source)
    model_config.output_hidden_states = args.output_hidden_states
    if use_huggingface:
        model = BertModel.from_pretrained(args.in_huggingface, config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_huggingface)
    else:
        model = BertModel(config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_vocab)

    weights_path = getattr(args, "in_model", None)
    if weights_path is not None:
        # The loaded object is wrapped under "model_state_dict" before unpacking.
        state = {"model_state_dict": utils.load_checkpoint(weights_path)}
        utils.unpack_checkpoint(checkpoint=state, model=model)

    model = model.eval()
    model, _, _, _, device = utils.process_components(model=model)

    frame = pd.read_csv(args.in_csv)
    frame = frame.dropna(subset=[args.txt_col])
    frame.to_csv(f"{args.out_prefix}.df.csv", index=False)
    frame = frame.reset_index().drop("index", axis=1)
    records = list(frame.to_dict("index").values())
    num_samples = len(records)

    text_tokenizer = partial(
        tokenize_text,
        strip=args.strip,
        lowercase=args.lowercase,
        remove_punctuation=args.remove_punctuation,
    )
    open_fn = LambdaReader(
        input_key=args.txt_col,
        output_key=None,
        lambda_fn=text_tokenizer,
        tokenizer=tokenizer,
        max_length=max_length,
    )

    dataloader = utils.get_loader(
        records,
        open_fn,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    def is_requested(layer):
        # Without an explicit level every layer is stored.
        return bert_level is None or bert_level == layer

    def storage_key(layer):
        # Non-string layer names are zero-padded to two digits.
        return layer if isinstance(layer, str) else f"{layer:02d}"

    features = {}
    if args.verbose:
        dataloader = tqdm(dataloader)

    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            batch = utils.any2device(batch, device)
            output = model(**batch)

            if args.mask_for_max_length:
                mask = batch["attention_mask"].unsqueeze(-1)
            else:
                mask = None

            # A wrapped model exposes its config through ``.module``.
            if utils.check_ddp_wrapped(model):
                bert_config = model.module.config
            else:
                bert_config = model.config

            batch_features = process_bert_output(
                bert_output=output,
                hidden_size=bert_config.hidden_size,
                output_hidden_states=bert_config.output_hidden_states,
                pooling_groups=pooling_groups,
                mask=mask,
            )

            first = step * batch_size
            last = min((step + 1) * batch_size, num_samples)
            rows = np.arange(first, last)

            for layer, value in batch_features.items():
                if not is_requested(layer):
                    continue
                key = storage_key(layer)
                if step == 0:
                    # Storage is created from the shape of the first output.
                    _, embedding_size = value.shape
                    features[key] = np.memmap(
                        f"{args.out_prefix}.{key}.npy",
                        dtype=np.float32,
                        mode="w+",
                        shape=(num_samples, embedding_size),
                    )
                features[key][rows] = _detach(value)

    if not args.force_save:
        return
    for key, storage in features.items():
        storage.flush()
        np.save(f"{args.out_prefix}.{key}.force.npy", storage, allow_pickle=False)
def _parse_loss_spec(spec):
    """Split one command-line loss specification into ``(name, weight)``.

    ``spec`` is either a bare loss name or a list/tuple produced by an
    ``nargs="+"`` option. A two-element sequence is ``[name, weight]``; any
    other sequence contributes only its first element with weight ``1.0``.
    """
    if isinstance(spec, (list, tuple)):
        if len(spec) == 2:
            loss_name, loss_weight = spec
            return loss_name, loss_weight
        return spec[0], 1.0
    return spec, 1.0


def _register_losses(specs, *, label, prefix, input_key, output_key, criterions, callbacks, loss_names,
                     **criterion_kwargs):
    """Create criterion callbacks for ``specs`` and record them in place.

    ``specs`` may be ``None`` (option not given), in which case nothing is
    registered. ``criterions``, ``callbacks`` and ``loss_names`` are mutated.
    """
    for spec in specs or ():
        loss_name, loss_weight = _parse_loss_spec(spec)
        cd, criterion, criterion_name = get_criterion_callback(
            loss_name,
            prefix=prefix,
            input_key=input_key,
            output_key=output_key,
            loss_weight=float(loss_weight),
            **criterion_kwargs,
        )
        criterions.update(cd)
        callbacks.append(criterion)
        loss_names.append(criterion_name)
        print(label, "Using loss", loss_name, loss_weight)


def main():
    """Command-line entry point: train a model on labeled plus pseudolabeled data.

    Parses the command line, builds the model (optionally transferring or
    loading weights), creates a new ``runs/<prefix>`` directory with a JSON
    dump of the arguments, and - when ``--epochs`` is positive - trains with
    the requested losses on the training set concatenated with the
    pseudolabeled set. After training the best checkpoint of the run is
    passed to ``clean_checkpoint``.

    Raises:
        ValueError: if training is requested but no loss was specified.
        FileExistsError: if the log directory already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA sattelite dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument(
        "--disaster-type-loss",
        type=str,
        default=None,
        action="append",
        nargs="+",
        help="Criterion for classifying disaster type",
    )
    parser.add_argument(
        "--damage-type-loss",
        type=str,
        default=None,
        action="append",
        nargs="+",
        help="Criterion for classifying presence of building with particular damage type",
    )
    parser.add_argument("-l", "--criterion", type=str, default=None, action="append", nargs="+", help="Criterion")
    parser.add_argument(
        "--mask4", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 4"
    )
    parser.add_argument(
        "--mask8", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 8"
    )
    parser.add_argument(
        "--mask16", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 16"
    )
    parser.add_argument(
        "--mask32", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 32"
    )
    parser.add_argument("--embedding", type=str, default=None)
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("--fold", default=0, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0.0, type=float, help="Dropout before head layer")
    parser.add_argument("-pl", "--pseudolabeling", type=str, required=True)
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--only-buildings", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")
    parser.add_argument("--crops", action="store_true", help="Train on random crops")
    parser.add_argument("--post-transform", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    segmentation_losses = args.criterion
    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    only_buildings = args.only_buildings
    freeze_bn = args.freeze_bn
    train_on_crops = args.crops
    enable_post_image_transform = args.post_transform
    disaster_type_loss = args.disaster_type_loss
    train_batch_size = args.batch_size
    embedding_criterion = args.embedding
    damage_type_loss = args.damage_type_loss
    pseudolabels_dir = args.pseudolabeling

    # Compute batch size for validation: when training on crops the validation
    # batch is scaled by the crop area relative to a 1024x1024 image.
    if train_on_crops:
        valid_batch_size = max(1, (train_batch_size * (image_size[0] * image_size[1])) // (1024**2))
    else:
        valid_batch_size = train_batch_size

    run_train = num_epochs > 0

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        torch_utils.freeze_bn(model)
        print("Freezing bn params")

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None)
    main_metric = "weighted_f1"
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_{args.size}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if pseudolabels_dir:
        checkpoint_prefix += "_pseudo"

    if train_on_crops:
        checkpoint_prefix += "_crops"

    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        f.write(json.dumps(cmd_args, indent=2))

    default_callbacks = [
        CompetitionMetricCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="weighted_f1"),
        ConfusionMatrixCallback(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            class_names=["land", "no_damage", "minor_damage", "major_damage", "destroyed"],
            ignore_index=UNLABELED_SAMPLE,
        ),
    ]

    if show:
        default_callbacks += [
            ShowPolarBatchesCallback(draw_predictions, metric=main_metric + "_batch", minimize=False)
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        fast=fast,
        fold=fold,
        balance=balance,
        only_buildings=only_buildings,
        train_on_crops=train_on_crops,
        crops_multiplication_factor=1,
        enable_post_image_transform=enable_post_image_transform,
    )

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        unlabeled_train = get_pseudolabeling_dataset(
            data_dir,
            include_masks=True,
            image_size=image_size,
            augmentation="medium_nmd",
            train_on_crops=train_on_crops,
            enable_post_image_transform=enable_post_image_transform,
            pseudolabels_dir=pseudolabels_dir,
        )

        train_ds = train_ds + unlabeled_train

        print("Using online pseudolabeling with ", len(unlabeled_train), "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        # Create losses. Every group goes through the same helper so that a
        # missing option (None) and a spec without an explicit weight are
        # handled uniformly.
        registry = dict(criterions=criterions_dict, callbacks=callbacks, loss_names=losses)

        _register_losses(
            segmentation_losses,
            label=INPUT_MASK_KEY,
            prefix="segmentation",
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            **registry,
        )

        auxiliary_masks = (
            (args.mask4, "mask4", OUTPUT_MASK_4_KEY),
            (args.mask8, "mask8", OUTPUT_MASK_8_KEY),
            (args.mask16, "mask16", OUTPUT_MASK_16_KEY),
            (args.mask32, "mask32", OUTPUT_MASK_32_KEY),
        )
        for mask_losses, mask_prefix, mask_output_key in auxiliary_masks:
            _register_losses(
                mask_losses,
                label=mask_output_key,
                prefix=mask_prefix,
                input_key=INPUT_MASK_KEY,
                output_key=mask_output_key,
                **registry,
            )

        if disaster_type_loss is not None:
            callbacks += [
                ConfusionMatrixCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    class_names=DISASTER_TYPES,
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                    prefix=f"{DISASTER_TYPE_KEY}/confusion_matrix",
                ),
                AccuracyCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    prefix=f"{DISASTER_TYPE_KEY}/accuracy",
                    activation="Softmax",
                ),
            ]

            _register_losses(
                disaster_type_loss,
                label=DISASTER_TYPE_KEY,
                prefix=DISASTER_TYPE_KEY,
                input_key=DISASTER_TYPE_KEY,
                output_key=DISASTER_TYPE_KEY,
                ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                **registry,
            )

        if damage_type_loss is not None:
            callbacks += [
                AccuracyCallback(
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    prefix=f"{DAMAGE_TYPE_KEY}/accuracy",
                    activation="Sigmoid",
                    threshold=0.5,
                )
            ]

            _register_losses(
                damage_type_loss,
                label=DAMAGE_TYPE_KEY,
                prefix=DAMAGE_TYPE_KEY,
                input_key=DAMAGE_TYPE_KEY,
                output_key=DAMAGE_TYPE_KEY,
                **registry,
            )

        if embedding_criterion is not None:
            cd, criterion, criterion_name = get_criterion_callback(
                embedding_criterion,
                prefix="embedding",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_EMBEDDING_KEY,
                loss_weight=1.0,
            )
            criterions_dict.update(cd)
            callbacks.append(criterion)
            losses.append(criterion_name)
            print(OUTPUT_EMBEDDING_KEY, "Using loss", embedding_criterion)

        if not losses:
            # Fail with a clear message instead of a TypeError deep in the setup.
            raise ValueError("At least one loss must be set (e.g. --criterion NAME [WEIGHT])")

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # Cyclic schedules step once per batch rather than once per epoch.
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", num_epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print("Data ")
        print(" Augmentations :", augmentations)
        print(" Train size :", len(loaders["train"]), len(train_ds))
        print(" Valid size :", len(loaders["valid"]), len(valid_ds))
        print(" Image size :", image_size)
        print(" Train on crops :", train_on_crops)
        print(" Balance :", balance)
        print(" Buildings only :", only_buildings)
        print(" Post transform :", enable_post_image_transform)
        print(" Pseudolabels :", pseudolabels_dir)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout)
        print("Optimizer :", optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", scheduler_name)
        print(" Batch sizes :", train_batch_size, valid_batch_size)
        print(" Criterion :", segmentation_losses)
        print(" Damage type :", damage_type_loss)
        print(" Disaster type :", disaster_type_loss)
        print(" Embedding :", embedding_criterion)

        # model training
        stage_dir = os.path.join(log_dir, "opl")
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=stage_dir,
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        # Training is finished. The checkpoints live under the directory that
        # was passed to the runner as ``logdir``, so read the best one from
        # there (not from a "main" stage this script never creates).
        best_checkpoint = os.path.join(stage_dir, "checkpoints", "best.pth")
        model_checkpoint = os.path.join(stage_dir, "checkpoints", f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        del optimizer, loaders
def main(args):
    """Distil a teacher sequence classifier into a student model.

    Loads ``args.dataset``, tokenizes its ``text`` field with the teacher's
    tokenizer, restores the teacher weights from ``args.teacher_path`` and
    trains the student with a weighted sum of KL-divergence, task and
    (when ``args.beta > 0``) hidden-state MSE losses. With ``args.wandb`` set,
    the best accuracy of the last ``args.num_epochs`` rows of ``valid.csv``
    is logged and the log directory is removed afterwards.
    """
    use_wandb = args.wandb
    if use_wandb:
        import wandb

        wandb.init()
        logdir = args.logdir + "/" + wandb.run.name
    else:
        logdir = args.logdir

    set_global_seed(args.seed)

    datasets = load_dataset(args.dataset)
    tokenizer = AutoTokenizer.from_pretrained(args.teacher_model)
    datasets = datasets.map(
        lambda e: tokenizer(
            e["text"], truncation=True, padding="max_length", max_length=128),
        batched=True,
    )
    datasets = datasets.map(lambda e: {"labels": e["label"]}, batched=True)
    datasets.set_format(
        type="torch",
        columns=["input_ids", "token_type_ids", "attention_mask", "labels"],
    )

    loaders = dict(
        train=DataLoader(datasets["train"], batch_size=args.batch_size, shuffle=True),
        valid=DataLoader(datasets["test"], batch_size=args.batch_size),
    )

    teacher_model = AutoModelForSequenceClassification.from_pretrained(
        args.teacher_model, num_labels=args.num_labels)
    unpack_checkpoint(torch.load(args.teacher_path), model=teacher_model)

    metric_callback = LoaderMetricCallback(
        metric=HFMetric(metric=load_metric("accuracy")),
        input_key="s_logits",
        target_key="labels",
    )

    layers = list(map(int, args.layers.split(",")))

    def train_only(callback):
        # Every distillation callback is restricted to the "train" loader.
        return ControlFlowCallback(callback, loaders="train")

    slct_callback = train_only(
        HiddenStatesSelectCallback(hiddens_key="t_hidden_states", layers=layers))
    lambda_hiddens_callback = train_only(
        LambdaPreprocessCallback(lambda s_hiddens, t_hiddens: (
            [c_s[:, 0] for c_s in s_hiddens],
            [t_s[:, 0] for t_s in t_hiddens],  # takes only the CLS token
        )))
    mse_hiddens = train_only(MSEHiddenStatesCallback())
    kl_div = train_only(KLDivCallback(temperature=args.kl_temperature))

    runner = HFDistilRunner()
    student_model = AutoModelForSequenceClassification.from_pretrained(
        args.student_model, num_labels=args.num_labels)

    callbacks = [
        metric_callback,
        slct_callback,
        lambda_hiddens_callback,
        kl_div,
        OptimizerCallback(metric_key="loss"),
        CheckpointCallback(logdir=logdir,
                           loader_key="valid",
                           mode="model",
                           metric_key="accuracy",
                           minimize=False),
    ]

    # The hidden-state MSE term takes part only when its weight is positive.
    with_hidden_mse = args.beta > 0
    loss_weights = {"kl_div_loss": args.alpha}
    if with_hidden_mse:
        loss_weights["mse_loss"] = args.beta
    loss_weights["task_loss"] = 1 - args.alpha

    aggregator = train_only(
        MetricAggregationCallback(
            prefix="loss",
            metrics=loss_weights,
            mode="weighted_sum",
        ))
    if with_hidden_mse:
        callbacks.append(mse_hiddens)
    callbacks.append(aggregator)

    models = torch.nn.ModuleDict({
        "teacher": teacher_model,
        "student": student_model
    })
    runner.train(model=models,
                 loaders=loaders,
                 optimizer=torch.optim.Adam(student_model.parameters(), lr=args.lr),
                 callbacks=callbacks,
                 num_epochs=args.num_epochs,
                 valid_metric="accuracy",
                 logdir=logdir,
                 minimize_valid_metric=False,
                 valid_loader="valid",
                 verbose=args.verbose,
                 seed=args.seed)

    if not use_wandb:
        return

    import csv
    import shutil

    with open(logdir + "/valid.csv") as fi:
        # Rows whose "accuracy" cell is the literal column name are skipped.
        accuracy = [
            float(row["accuracy"])
            for row in csv.DictReader(fi)
            if row["accuracy"] != "accuracy"
        ]
    wandb.log({"accuracy": max(accuracy[-args.num_epochs:])})
    shutil.rmtree(logdir)
def run_stage_training(model: Union[TimmRgbModel, YCrCbModel], config: StageConfig, exp_config: ExperimenetConfig, experiment_dir: str):
    """Train ``model`` for a single stage described by ``config``.

    Builds datasets, losses, callbacks, loaders, optimizer and scheduler from
    the stage and experiment configs, prints a summary, runs the training and
    passes the stage's best checkpoint to ``clean_checkpoint``, writing to
    ``<experiment_dir>/<checkpoint_prefix>.pth``. When ``config.restore_best``
    is set, those weights are loaded back into ``model``.
    """
    # Preparing model
    freeze_model(model, freeze_bn=config.freeze_bn)

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=exp_config.data_dir,
        image_size=config.image_size,
        augmentation=config.augmentations,
        balance=config.balance,
        fast=config.fast,
        fold=exp_config.fold,
        features=model.required_features,
        obliterate_p=config.obliterate_p,
    )

    criterions_dict, loss_callbacks = get_criterions(
        modification_flag=config.modification_flag_loss,
        modification_type=config.modification_type_loss,
        embedding_loss=config.embedding_loss,
        feature_maps_loss=config.feature_maps_loss,
        num_epochs=config.epochs,
        mixup=config.mixup,
        cutmix=config.cutmix,
        tsa=config.tsa,
    )

    hyper_parameters = {
        "model": exp_config.model_name,
        "scheduler": config.schedule,
        "optimizer": config.optimizer,
        "augmentations": config.augmentations,
        "size": config.image_size[0],
        "weight_decay": config.weight_decay,
    }
    callbacks = loss_callbacks + [
        OptimizerCallback(accumulation_steps=config.accumulation_steps, decouple_weight_decay=False),
        HyperParametersCallback(hparam_dict=hyper_parameters),
    ]
    if config.show:
        callbacks.append(ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True))

    # Loaders are created inline so that the later ``del loaders`` drops the
    # only references held by this function.
    loaders = collections.OrderedDict(
        [
            (
                "train",
                DataLoader(
                    train_ds,
                    batch_size=config.train_batch_size,
                    num_workers=exp_config.num_workers,
                    pin_memory=True,
                    drop_last=True,
                    # A sampler and shuffling are mutually exclusive.
                    shuffle=train_sampler is None,
                    sampler=train_sampler,
                ),
            ),
            (
                "valid",
                DataLoader(
                    valid_ds,
                    batch_size=config.valid_batch_size,
                    num_workers=exp_config.num_workers,
                    pin_memory=True,
                ),
            ),
        ]
    )

    # Each row is printed as one line, fields separated by a space.
    summary = [
        ("Stage :", config.stage_name),
        (" FP16 mode :", config.fp16),
        (" Fast mode :", config.fast),
        (" Epochs :", config.epochs),
        (" Workers :", exp_config.num_workers),
        (" Data dir :", exp_config.data_dir),
        (" Experiment dir :", experiment_dir),
        ("Data ",),
        (" Augmentations :", config.augmentations),
        (" Obliterate (%) :", config.obliterate_p),
        (" Negative images:", config.negative_image_dir),
        (" Train size :", len(loaders["train"]), "batches", len(train_ds), "samples"),
        (" Valid size :", len(loaders["valid"]), "batches", len(valid_ds), "samples"),
        (" Image size :", config.image_size),
        (" Balance :", config.balance),
        (" Mixup :", config.mixup),
        (" CutMix :", config.cutmix),
        (" TSA :", config.tsa),
        ("Model :", exp_config.model_name),
        (" Parameters :", count_parameters(model)),
        (" Dropout :", exp_config.dropout),
        ("Optimizer :", config.optimizer),
        (" Learning rate :", config.learning_rate),
        (" Weight decay :", config.weight_decay),
        (" Scheduler :", config.schedule),
        (" Batch sizes :", config.train_batch_size, config.valid_batch_size),
        ("Losses ",),
        (" Flag :", config.modification_flag_loss),
        (" Type :", config.modification_type_loss),
        (" Embedding :", config.embedding_loss),
        (" Feature maps :", config.feature_maps_loss),
    ]
    for fields in summary:
        print(*fields)

    optimizer = get_optimizer(
        config.optimizer,
        get_optimizable_parameters(model),
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    scheduler = get_scheduler(
        config.schedule,
        optimizer,
        lr=config.learning_rate,
        num_epochs=config.epochs,
        batches_in_epoch=len(loaders["train"]),
    )
    # Cyclic schedules are stepped per batch.
    if isinstance(scheduler, CyclicLR):
        callbacks.append(SchedulerCallback(mode="batch"))

    # model training
    stage_dir = os.path.join(experiment_dir, config.stage_name)
    runner = SupervisedRunner(input_key=model.required_features, output_key=None)
    runner.train(
        fp16=config.fp16,
        model=model,
        criterion=criterions_dict,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=stage_dir,
        num_epochs=config.epochs,
        verbose=config.verbose,
        main_metric=config.main_metric,
        minimize_metric=config.main_metric_minimize,
        checkpoint_data={"config": config},
    )

    del optimizer, loaders, callbacks, runner

    best_checkpoint = os.path.join(stage_dir, "checkpoints", "best.pth")
    model_checkpoint = os.path.join(experiment_dir, f"{exp_config.checkpoint_prefix}.pth")
    clean_checkpoint(best_checkpoint, model_checkpoint)

    # Restore state of best model
    if config.restore_best:
        best_state = load_checkpoint(model_checkpoint)
        unpack_checkpoint(best_state, model=model)

    # Some memory cleanup
    torch.cuda.empty_cache()
    gc.collect()
if args.multigpu: model = nn.DataParallel(model) if args.task == 'segmentation': callbacks = [DiceCallback(), EarlyStoppingCallback(patience=10, min_delta=0.001), CriterionCallback()] elif args.task == 'classification': callbacks = [AUCCallback(class_names=['Fish', 'Flower', 'Gravel', 'Sugar'], num_classes=4), EarlyStoppingCallback(patience=10, min_delta=0.001), CriterionCallback()] if args.gradient_accumulation: callbacks.append(OptimizerCallback(accumulation_steps=args.gradient_accumulation)) checkpoint = utils.load_checkpoint(f'{logdir}/checkpoints/best.pth') model.cuda() utils.unpack_checkpoint(checkpoint, model=model) # # runner = SupervisedRunner() if args.train: print('Training') runner.train( model=model, criterion=criterion, optimizer=optimizer, main_metric='dice', minimize_metric=False, scheduler=scheduler, loaders=loaders, callbacks=callbacks, logdir=logdir,
def main():
    """Command-line entry point: train and/or run a binary image classifier.

    ``--run-mode`` selects the phases: ``fit`` trains, ``predict`` only runs
    inference, ``fit_predict`` (default) does both. Training creates a new
    ``runs/adversarial/<model>/<time>_<criterion>...`` log directory. The
    prediction phase (skipped with ``--fast``) loads ``best.pth`` from the log
    directory, scores every image listed in ``<data-dir>/train.csv`` and
    writes ``test_in_train.csv`` with columns ``id_code`` and ``is_test``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('-dd', '--data-dir', type=str, default='data',
                        help='Data directory for INRIA sattelite dataset')
    parser.add_argument('-m', '--model', type=str, default='cls_resnet18', help='')
    parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e', '--epochs', type=int, default=100, help='Epoch to run')
    parser.add_argument('-es', '--early-stopping', type=int, default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-fe', '--freeze-encoder', action='store_true')
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4, help='Initial learning rate')
    parser.add_argument('-l', '--criterion', type=str, default='bce', help='Criterion')
    parser.add_argument('-o', '--optimizer', default='Adam', help='Name of the optimizer')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w', '--workers', default=multiprocessing.cpu_count(), type=int, help='Num workers')
    parser.add_argument('-a', '--augmentations', default='hard', type=str, help='')
    parser.add_argument('-tta', '--tta', default=None, type=str, help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-tm', '--train-mode', default='random', type=str, help='')
    parser.add_argument('-rm', '--run-mode', default='fit_predict', type=str, help='')
    parser.add_argument('--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    run_mode = args.run_mode
    log_dir = None
    fp16 = args.fp16
    freeze_encoder = args.freeze_encoder

    run_train = run_mode == 'fit_predict' or run_mode == 'fit'
    run_predict = run_mode == 'fit_predict' or run_mode == 'predict'

    model = maybe_cuda(get_model(model_name, num_classes=1))

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint['model_state_dict']

        # Best-effort transfer: tensors are loaded one at a time so that a
        # single incompatible entry is reported and skipped, not fatal.
        for name, value in pretrained_dict.items():
            try:
                model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
            except Exception as e:
                print(e)

    # The transfer checkpoint must not be used to restore optimizer state.
    checkpoint = None
    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        checkpoint_epoch = checkpoint['epoch']
        print('Loaded model weights from:', args.checkpoint)
        print('Epoch                    :', checkpoint_epoch)
        print('Metrics (Train):',
              'f1  :', checkpoint['epoch_metrics']['train']['f1_score'],
              'loss:', checkpoint['epoch_metrics']['train']['loss'])
        print('Metrics (Valid):',
              'f1  :', checkpoint['epoch_metrics']['valid']['f1_score'],
              'loss:', checkpoint['epoch_metrics']['valid']['loss'])

        # The log directory is two levels above the checkpoint file.
        log_dir = os.path.dirname(
            os.path.dirname(fs.auto_file(args.checkpoint)))

    if run_train:
        if freeze_encoder:
            set_trainable(model.encoder, trainable=False, freeze_bn=True)

        criterion = get_loss(args.criterion)
        parameters = get_optimizable_parameters(model)
        optimizer = get_optimizer(optimizer_name, parameters, learning_rate)

        if checkpoint is not None:
            try:
                unpack_checkpoint(checkpoint, optimizer=optimizer)
                print('Restored optimizer state from checkpoint')
            except Exception as e:
                print('Failed to restore optimizer state from checkpoint', e)

        train_loader, valid_loader = get_dataloaders(
            data_dir=data_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        current_time = datetime.now().strftime('%b%d_%H_%M')
        prefix = f'adversarial/{args.model}/{current_time}_{args.criterion}'

        if fp16:
            prefix += '_fp16'

        if fast:
            prefix += '_fast'

        log_dir = os.path.join('runs', prefix)
        os.makedirs(log_dir, exist_ok=False)

        scheduler = MultiStepLR(optimizer, milestones=[10, 30, 50, 70, 90], gamma=0.5)

        print('Train session :', prefix)
        print('\tFP16 mode :', fp16)
        print('\tFast mode :', args.fast)
        print('\tTrain mode :', train_mode)
        print('\tEpochs :', num_epochs)
        print('\tEarly stopping :', early_stopping)
        print('\tWorkers :', num_workers)
        print('\tData dir :', data_dir)
        print('\tLog dir :', log_dir)
        print('\tAugmentations :', augmentations)
        print('\tTrain size :', len(train_loader), len(train_loader.dataset))
        print('\tValid size :', len(valid_loader), len(valid_loader.dataset))
        print('Model :', model_name)
        print('\tParameters :', count_parameters(model))
        print('\tImage size :', image_size)
        print('\tFreeze encoder :', freeze_encoder)
        print('Optimizer :', optimizer_name)
        print('\tLearning rate :', learning_rate)
        print('\tBatch size :', batch_size)
        print('\tCriterion :', args.criterion)

        # model training
        visualization_fn = partial(draw_classification_predictions, class_names=['Train', 'Test'])

        callbacks = [
            F1ScoreCallback(),
            AUCCallback(),
            ShowPolarBatchesCallback(visualization_fn, metric='f1_score', minimize=False),
        ]

        if early_stopping:
            callbacks += [
                EarlyStoppingCallback(early_stopping, metric='auc', minimize=False)
            ]

        runner = SupervisedRunner(input_key='image')
        runner.train(fp16=fp16,
                     model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     callbacks=callbacks,
                     loaders=loaders,
                     logdir=log_dir,
                     num_epochs=num_epochs,
                     verbose=True,
                     main_metric='auc',
                     minimize_metric=False,
                     state_kwargs={"cmd_args": vars(args)})

    if run_predict and not fast:
        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = load_checkpoint(
            fs.auto_file('best.pth', where=log_dir))
        unpack_checkpoint(best_checkpoint, model=model)

        model.eval()

        train_csv = pd.read_csv(os.path.join(data_dir, 'train.csv'))
        train_csv['id_code'] = train_csv['id_code'].apply(
            lambda x: os.path.join(data_dir, 'train_images', f'{x}.png'))
        test_ds = RetinopathyDataset(train_csv['id_code'], None, get_test_aug(image_size), target_as_array=True)
        test_dl = DataLoader(test_ds, batch_size, pin_memory=True, num_workers=num_workers)

        test_ids = []
        test_preds = []

        # torch.no_grad() only takes effect as a context manager (or
        # decorator); calling it as a bare statement leaves autograd enabled.
        with torch.no_grad():
            for batch in tqdm(test_dl, desc='Inference'):
                images = batch['image'].cuda()
                outputs = model(images)
                predictions = to_numpy(outputs['logits'].sigmoid().squeeze(1))
                test_ids.extend(batch['image_id'])
                test_preds.extend(predictions)

        df = pd.DataFrame.from_dict({
            'id_code': test_ids,
            'is_test': test_preds
        })
        df.to_csv(os.path.join(log_dir, 'test_in_train.csv'), index=False)
def main():
    """Command-line entry point: train a model in up to three sequential stages.

    The function parses command-line options, builds the model (optionally
    initialising it from ``--transfer`` and/or ``--checkpoint`` weights),
    creates a fresh run directory ``runs/<prefix>`` and dumps the parsed
    arguments there as JSON.  It then runs, in this order and each only when
    enabled:

    1. **warmup** (``--warmup N``): ``N`` epochs with a fixed "Ranger"
       optimizer, no scheduler and ``obliterate_p=0``; logs go to
       ``<log_dir>/warmup``.
    2. **main** (``--epochs N`` with ``N > 0``): the user-selected optimizer
       and scheduler; logs go to ``<log_dir>/main``.
    3. **fine-tune** (``--fine-tune N``): ``N`` epochs with "SGD", a "cos"
       scheduler, "light" augmentations and mixup/cutmix/tsa disabled; logs go
       to ``<log_dir>/finetune``.  The resulting best checkpoint is loaded
       back into the model.

    After every stage the stage's ``best.pth`` is passed through
    ``clean_checkpoint`` and written next to the run directory's config file.

    Raises:
        SystemExit: via ``parser.error`` when none of the modification-flag,
            modification-type or embedding losses is given.
        FileExistsError: when the run directory already exists
            (``os.makedirs(..., exist_ok=False)``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Chance of obliteration")
    parser.add_argument(
        "-nid",
        "--negative-image-dir",
        type=str,
        default=None,
        help="Directory with negative images to add to the training set of the main stage",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs",
        "--warmup-batch-size",
        type=int,
        default=None,
        help="Batch Size during warmup, e.g. -wbs 64 (defaults to --batch-size)",
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument("-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters for N epochs")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    # Every loss option may be repeated; argparse collects each occurrence as
    # one list of string tokens, e.g. `-l ce 1.0` -> [["ce", "1.0"]].
    parser.add_argument("-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--modification-type-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--mask-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--bits-loss", type=str, default=None, action="append", nargs="+")

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument(
        "--fine-tune", default=0, type=int, help="Number of fine-tuning epochs to run after the main training stage"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()

    # Validate with parser.error rather than `assert`: asserts are removed
    # under `python -O`, which would let a run without any loss slip through.
    if not (args.modification_flag_loss or args.modification_type_loss or args.embedding_loss):
        parser.error("At least one of losses must be set")

    set_manual_seed(args.seed)

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss

    freeze_encoder = args.freeze_encoder
    data_dir = args.data_dir
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name: str = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    fine_tune = args.fine_tune
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir
    warmup_batch_size = args.warmup_batch_size or args.batch_size

    # Validation uses the same batch size as training.
    valid_batch_size = train_batch_size

    # `--epochs 0` skips the main stage; warmup and fine-tune are gated separately.
    run_train = num_epochs > 0

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)
    if embedding_loss is not None:
        custom_model_kwargs["need_embedding"] = True

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()

    # NOTE(review): no copy is made here, so the append below modifies the very
    # list object returned by `model.required_features` - confirm this is intended.
    required_features = model.required_features
    if mask_loss is not None:
        required_features.append(INPUT_TRUE_MODIFICATION_MASK)

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        from pytorch_toolbelt.optimization.functional import freeze_model

        freeze_model(model, freeze_bn=True)
        print("Freezing bn params")

    main_metric = "loss"
    main_metric_minimize = True

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to write into the directory of an earlier run.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = []

    if show:
        default_callbacks += [ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True)]

    # Pretrain/warmup
    if warmup:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=0,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            feature_maps_loss=feature_maps_loss,
            num_epochs=warmup,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=warmup_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=warmup_batch_size, num_workers=num_workers, pin_memory=True)

        if freeze_encoder:
            from pytorch_toolbelt.optimization.functional import freeze_model

            freeze_model(model.encoder, freeze_parameters=True, freeze_bn=None)

        # Warmup ignores --optimizer / --learning-rate / --scheduler and uses
        # these fixed settings; they are named so the banner below reports them.
        warmup_optimizer_name = "Ranger"
        warmup_learning_rate = 3e-4
        optimizer = get_optimizer(
            warmup_optimizer_name,
            get_optimizable_parameters(model),
            weight_decay=weight_decay,
            learning_rate=warmup_learning_rate,
        )
        scheduler = None

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", warmup)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print(" Cache :", cache)
        print("Data ")
        print(" Augmentations :", augmentations)
        print(" Negative images:", negative_image_dir)
        print(" Train size :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print(" Valid size :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print(" Image size :", image_size)
        print(" Balance :", balance)
        print(" Mixup :", mixup)
        print(" CutMix :", cutmix)
        print(" TSA :", tsa)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout, "(Non-default)" if dropout is not None else "")
        print("Optimizer :", warmup_optimizer_name)
        print(" Learning rate :", warmup_learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", scheduler)
        print(" Batch sizes :", warmup_batch_size, warmup_batch_size)
        print("Losses ")
        print(" Flag :", modification_flag_loss)
        print(" Type :", modification_type_loss)
        print(" Embedding :", embedding_loss)
        print(" Feature maps :", feature_maps_loss)
        print(" Mask :", mask_loss)
        print(" Bits :", bits_loss)

        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Drop references before the next stage allocates its own objects.
        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        # The best warmup weights are intentionally not restored here:
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        if negative_image_dir:
            negatives_ds = get_negatives_ds(
                negative_image_dir, fold=fold, features=required_features, max_images=16536
            )
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds), "negative samples to training set")

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=num_epochs,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", num_epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print(" Cache :", cache)
        print("Data ")
        print(" Augmentations :", augmentations)
        print(" Obliterate (%) :", obliterate_p)
        print(" Negative images:", negative_image_dir)
        print(" Train size :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print(" Valid size :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print(" Image size :", image_size)
        print(" Balance :", balance)
        print(" Mixup :", mixup)
        print(" CutMix :", cutmix)
        print(" TSA :", tsa)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout)
        print("Optimizer :", optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", scheduler_name)
        print(" Batch sizes :", train_batch_size, valid_batch_size)
        print("Losses ")
        print(" Flag :", modification_flag_loss)
        print(" Type :", modification_type_loss)
        print(" Embedding :", embedding_loss)
        print(" Feature maps :", feature_maps_loss)
        print(" Mask :", mask_loss)
        print(" Bits :", bits_loss)

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # A CyclicLR scheduler gets a SchedulerCallback in "batch" mode.
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")

        clean_checkpoint(best_checkpoint, model_checkpoint)

        # The best main-stage weights are intentionally not restored here:
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        torch.cuda.empty_cache()
        gc.collect()

    if fine_tune:
        # Fine-tuning ignores --augmentations / --optimizer / --scheduler and
        # the mixup/cutmix/tsa switches; the fixed settings are named here so
        # the calls and the banner below cannot drift apart.
        fine_tune_augmentations = "light"
        fine_tune_optimizer_name = "SGD"
        fine_tune_scheduler_name = "cos"

        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=fine_tune_augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=fine_tune,
            mixup=False,
            cutmix=False,
            tsa=False,
        )

        # NOTE(review): the hparams below still record the run-level
        # scheduler/optimizer/augmentation names, not the fixed fine-tune
        # settings above - confirm this is intended.
        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", fine_tune)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print(" Cache :", cache)
        print("Data ")
        print(" Augmentations :", fine_tune_augmentations)
        print(" Obliterate (%) :", obliterate_p)
        print(" Negative images:", negative_image_dir)
        print(" Train size :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print(" Valid size :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print(" Image size :", image_size)
        print(" Balance :", balance)
        print(" Mixup :", False)
        print(" CutMix :", False)
        print(" TSA :", False)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout)
        print("Optimizer :", fine_tune_optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", fine_tune_scheduler_name)
        print(" Batch sizes :", train_batch_size, valid_batch_size)
        print("Losses ")
        print(" Flag :", modification_flag_loss)
        print(" Type :", modification_type_loss)
        print(" Embedding :", embedding_loss)
        print(" Feature maps :", feature_maps_loss)
        print(" Mask :", mask_loss)
        print(" Bits :", bits_loss)

        optimizer = get_optimizer(
            fine_tune_optimizer_name,
            get_optimizable_parameters(model),
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        )
        scheduler = get_scheduler(
            fine_tune_scheduler_name,
            optimizer,
            lr=learning_rate,
            num_epochs=fine_tune,
            batches_in_epoch=len(loaders["train"]),
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "finetune"),
            num_epochs=fine_tune,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        best_checkpoint = os.path.join(log_dir, "finetune", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_finetune.pth")

        clean_checkpoint(best_checkpoint, model_checkpoint)
        # Unlike the earlier stages, the best fine-tune weights ARE loaded back into the model.
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        del optimizer, loaders, runner, callbacks
def main(args):
    """Train a student network on CIFAR-100 under the guidance of a teacher network.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``batch_size``, ``teacher``, ``teacher_path``, ``student``, ``lr``,
            ``probability_shift``, ``logdir``, ``alpha``, ``beta``,
            ``num_epochs`` and ``seed``.

    The teacher weights come from ``args.teacher_path`` when it is given,
    otherwise they are downloaded from ``NAME2URL[args.teacher]``.  Only the
    student's parameters are handed to the optimizer.
    """
    set_global_seed(42)

    # Mean / std values passed to Normalize in both pipelines.
    norm_mean = (0.4914, 0.4822, 0.4465)
    norm_std = (0.2023, 0.1994, 0.2010)

    # Training pipeline adds random crop + horizontal flip; evaluation does not.
    train_pipeline = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ]
    )
    eval_pipeline = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ]
    )

    # Only the train split is requested with download=True.
    train_set = Wrp(CIFAR100(root=".", train=True, download=True, transform=train_pipeline))
    valid_set = Wrp(CIFAR100(root=".", train=False, transform=eval_pipeline))

    loaders = {}
    for split, dataset in (("train", train_set), ("valid", valid_set)):
        loaders[split] = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=split == "train",  # shuffle the training split only
            num_workers=2,
        )

    teacher = NAME2MODEL[args.teacher](num_classes=100)
    if args.teacher_path is not None:
        unpack_checkpoint(torch.load(args.teacher_path), model=teacher)
    else:
        teacher.load_state_dict(load_state_dict_from_url(NAME2URL[args.teacher]))

    student = NAME2MODEL[args.student](num_classes=100)

    optimizer = torch.optim.SGD(student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [150, 180, 210], gamma=0.1)

    runner = DistilRunner(apply_probability_shift=args.probability_shift)

    # Aggregated "loss" = beta * attention + alpha * KL-div + (1 - alpha) * classification.
    loss_weights = {
        "attention_loss": args.beta,
        "kl_div_loss": args.alpha,
        "cls_loss": 1 - args.alpha,
    }
    aggregate_loss = MetricAggregationCallback(prefix="loss", metrics=loss_weights, mode="weighted_sum")

    # The distillation terms and their aggregation are restricted to the
    # "train" loader through ControlFlowCallback.
    callbacks = [
        ControlFlowCallback(AttentionHiddenStatesCallback(), loaders="train"),
        ControlFlowCallback(KLDivCallback(temperature=4), loaders="train"),
        CriterionCallback(input_key="s_logits", target_key="targets", metric_key="cls_loss"),
        ControlFlowCallback(aggregate_loss, loaders="train"),
        AccuracyCallback(input_key="s_logits", target_key="targets"),
        OptimizerCallback(metric_key="loss", model_key="student"),
        SchedulerCallback(),
    ]
    criterion = torch.nn.CrossEntropyLoss()

    runner.train(
        model={"teacher": teacher, "student": student},
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        valid_metric="accuracy",
        minimize_valid_metric=False,
        logdir=args.logdir,
        callbacks=callbacks,
        valid_loader="valid",
        num_epochs=args.num_epochs,
        criterion=criterion,
        seed=args.seed,
    )