def get_criterions(
    criterions,
    criterions_stride2=None,
    criterions_stride4=None,
    criterions_stride8=None,
    criterions_stride16=None,
    ignore_index=None,
) -> Tuple[List[Callback], Dict]:
    criterions_dict = {}
    losses = []
    callbacks = []

    # Create main losses
    for loss_name, loss_weight in criterions:
        criterion_callback = CriterionCallback(
            prefix=f"{OUTPUT_MASK_KEY}/" + loss_name,
            input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
            output_key=OUTPUT_MASK_KEY,
            criterion_key=f"{OUTPUT_MASK_KEY}/" + loss_name,
            multiplier=float(loss_weight),
        )

        criterions_dict[criterion_callback.criterion_key] = get_loss(loss_name, ignore_index=ignore_index)
        callbacks.append(criterion_callback)
        losses.append(criterion_callback.prefix)
        print("Using loss", loss_name, loss_weight)

    # Additional supervision losses
    for supervision_losses, supervision_output in zip(
        [criterions_stride2, criterions_stride4, criterions_stride8, criterions_stride16],
        [OUTPUT_MASK_2_KEY, OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY],
    ):
        if supervision_losses is not None:
            for loss_name, loss_weight in supervision_losses:
                prefix = f"{supervision_output}/" + loss_name
                criterion_callback = CriterionCallback(
                    prefix=prefix,
                    input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
                    output_key=supervision_output,
                    criterion_key=prefix,
                    multiplier=float(loss_weight),
                )

                criterions_dict[criterion_callback.criterion_key] = ResizeTargetToPrediction2d(
                    get_loss(loss_name, ignore_index=ignore_index)
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)
                print("Using loss", loss_name, loss_weight)

    callbacks.append(MetricAggregationCallback(prefix="loss", metrics=losses, mode="sum"))
    return callbacks, criterions_dict
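# A minimal usage sketch (assumed wiring, matching the runner.train() calls that appear
# later in this file): the returned callbacks go to `callbacks=` and the criterions dict
# goes to `criterion=`, so each CriterionCallback can look up its loss by `criterion_key`.
# The loss specs, `model`, `optimizer` and `loaders` below are illustrative placeholders.
callbacks, criterions_dict = get_criterions(
    criterions=[("bce", 1.0), ("dice", 0.5)],
    criterions_stride4=[("soft_bce", 0.25)],
)
runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None)
runner.train(
    model=model,
    criterion=criterions_dict,
    optimizer=optimizer,
    callbacks=callbacks,
    loaders=loaders,
    num_epochs=50,
    main_metric="loss",
    minimize_metric=True,
)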
def test_fn_ends_with_pass_on_callbacks():
    """@TODO: Docs. Contribution is welcome."""

    def test_fn_ends_with_pass_on_callback(callback, events):
        # Covered events have real handlers; non-covered ones are empty `pass` stubs.
        for event in events["covered"]:
            fn_name = f"on_{event}"
            assert utils.fn_ends_with_pass(getattr(callback.__class__, fn_name)) is False
        for event in events["non-covered"]:
            fn_name = f"on_{event}"
            assert utils.fn_ends_with_pass(getattr(callback.__class__, fn_name)) is True

    # Callback test
    from catalyst.dl import Callback

    callback = Callback(order=1)
    start_events = ["stage_start", "epoch_start", "batch_start", "loader_start"]
    end_events = ["stage_end", "epoch_end", "batch_end", "loader_end", "exception"]
    events = {"covered": [], "non-covered": [*start_events, *end_events]}
    test_fn_ends_with_pass_on_callback(callback=callback, events=events)

    # CriterionCallback test
    from catalyst.dl import CriterionCallback

    callback = CriterionCallback()
    covered_events = ["stage_start", "batch_end"]
    non_covered_start_events = ["epoch_start", "batch_start", "loader_start"]
    non_covered_end_events = ["stage_end", "epoch_end", "loader_end", "exception"]
    events = {
        "covered": [*covered_events],
        "non-covered": [*non_covered_start_events, *non_covered_end_events],
    }
    test_fn_ends_with_pass_on_callback(callback=callback, events=events)
def get_criterion_callback(loss_name, input_key, output_key, prefix=None, loss_weight=1.0, ignore_index=UNLABELED_SAMPLE):
    # Resolve the default prefix first, so the criterions_dict key always matches
    # the criterion_key used by the callback below.
    if prefix is None:
        prefix = input_key

    criterions_dict = {
        f"{prefix}/{loss_name}": get_loss(loss_name, ignore_index=ignore_index)
    }
    criterion_callback = CriterionCallback(
        prefix=f"{prefix}/{loss_name}",
        input_key=input_key,
        output_key=output_key,
        criterion_key=f"{prefix}/{loss_name}",
        multiplier=float(loss_weight),
    )

    return criterions_dict, criterion_callback, criterion_callback.prefix
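# Hedged usage sketch: build one callback per loss, merge the per-loss criterion dicts,
# and aggregate the individual loss prefixes. The loss names and the "targets"/"logits"
# keys are illustrative placeholders, not names taken from this project.
criterions = {}
callbacks = []
losses = []
for loss_name, weight in [("bce", 1.0), ("focal", 0.5)]:
    crit_dict, crit_callback, loss_prefix = get_criterion_callback(
        loss_name,
        input_key="targets",
        output_key="logits",
        loss_weight=weight,
    )
    criterions.update(crit_dict)
    callbacks.append(crit_callback)
    losses.append(loss_prefix)
callbacks.append(MetricAggregationCallback(prefix="loss", metrics=losses, mode="sum"))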
def create_callbacks(args, criterion_names):
    callbacks = [
        IoUMetricsCallback(
            mode=args.dice_mode,
            input_key=args.input_target_key,
            class_names=args.class_names.split(',') if args.class_names else None,
        ),
        CheckpointCallback(save_n_best=args.save_n_best),
        EarlyStoppingCallback(
            patience=args.patience,
            metric=args.eval_metric,
            minimize=(args.eval_metric == 'loss'),
        ),
    ]

    metrics_weights = {}
    for cn in criterion_names:
        callbacks.append(
            CriterionCallback(input_key=args.input_target_key, prefix=f"loss_{cn}", criterion_key=cn)
        )
        metrics_weights[f'loss_{cn}'] = 1.0
    callbacks.append(
        MetricAggregationCallback(prefix="loss", mode="weighted_sum", metrics=metrics_weights)
    )
    return callbacks
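# Hedged usage sketch: criterion_names must match the keys of the criterion dict that is
# later passed to runner.train(criterion=...). The concrete losses, `args`, `model`,
# `optimizer` and `loaders` are placeholders here.
criterion = {"dice": DiceLoss(), "bce": nn.BCEWithLogitsLoss()}
callbacks = create_callbacks(args, criterion_names=list(criterion.keys()))
runner.train(model=model, criterion=criterion, optimizer=optimizer,
             loaders=loaders, callbacks=callbacks, num_epochs=args.epochs)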
mixup_alpha = get_dict_value_or_default(config, key='mixup_alpha', default_value=0.3)
if mixup:
    callbacks.extend([
        MixupCallback(crit_key='h1', input_key='h1_targets', output_key='h1_logits',
                      alpha=mixup_alpha, on_train_only=False),
    ])
else:
    callbacks.extend([
        CriterionCallback(input_key="h1_targets", output_key="h1_logits",
                          prefix="loss_h1", criterion_key="h1"),
        CriterionCallback(input_key="h2_targets", output_key="h2_logits",
                          prefix="loss_h2", criterion_key="h2"),
        CriterionCallback(input_key="h3_targets", output_key="h3_logits",
                          prefix="loss_h3", criterion_key="h3"),
        crit_agg,
    ])

callbacks.extend([
    score_callback,
    EarlyStoppingCallback(metric='weight_recall',
                        max_lr=0.0016, steps_per_epoch=1, epochs=num_epochs)
# scheduler = OneCycleLRWithWarmup(
#     optimizer,
#     num_steps=num_epochs,
#     lr_range=(0.0016, 0.0000001),
#     init_lr=learning_rate,
#     warmup_steps=15
# )

loaders = get_loaders(preprocessing_fn, batch_size=8)

callbacks = [
    # Each criterion is calculated separately.
    CriterionCallback(input_key="mask", prefix="loss_dice", criterion_key="dice"),
    CriterionCallback(input_key="mask", prefix="loss_iou", criterion_key="iou"),
    CriterionCallback(input_key="mask", prefix="loss_bce", criterion_key="bce"),
    ClasswiseIouCallback(input_key="mask", prefix='clswise_iou', classes=CLASSES.keys()),

    # And only then we aggregate everything into one loss.
    MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum",  # can be "sum", "weighted_sum" or "mean"
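# The aggregation call above is truncated before its `metrics=` argument. A plausible,
# hedged completion (the weights are illustrative, not taken from this project) pairs
# each "loss_*" prefix with a scale and matches a criterion dict handed to the runner:
criterion = {"dice": DiceLoss(), "iou": IoULoss(), "bce": nn.BCEWithLogitsLoss()}
loss_aggregation = MetricAggregationCallback(
    prefix="loss",
    mode="weighted_sum",
    metrics={"loss_dice": 1.0, "loss_iou": 1.0, "loss_bce": 0.8},
)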
def main(): parser = argparse.ArgumentParser() ########################################################################################### # Distributed-training related stuff parser.add_argument("--local_rank", type=int, default=0) ########################################################################################### parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument("--fast", action="store_true") parser.add_argument( "-dd", "--data-dir", type=str, help="Data directory for INRIA sattelite dataset", default=os.environ.get("INRIA_DATA_DIR"), ) parser.add_argument("-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset") parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="") parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64") parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run") # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement') # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs') # parser.add_argument('-ft', '--fine-tune', action='store_true') parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate") parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion") parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer") parser.add_argument( "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights") parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers") parser.add_argument("-a", "--augmentations", default="hard", type=str, help="") parser.add_argument("-tm", "--train-mode", default="random", type=str, help="") parser.add_argument("--run-mode", default="fit_predict", type=str, help="") parser.add_argument("--transfer", default=None, type=str, help="") parser.add_argument("--fp16", action="store_true") parser.add_argument("--size", default=512, type=int) parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="") parser.add_argument("-x", "--experiment", default=None, type=str, help="") parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer") parser.add_argument("--opl", action="store_true") parser.add_argument( "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters") parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay") parser.add_argument("--show", action="store_true") parser.add_argument("--dsv", action="store_true") args = parser.parse_args() args.is_master = args.local_rank == 0 args.distributed = False fp16 = args.fp16 if "WORLD_SIZE" in os.environ: args.distributed = int(os.environ["WORLD_SIZE"]) > 1 args.world_size = int(os.environ["WORLD_SIZE"]) # args.world_size = torch.distributed.get_world_size() print("Initializing init_process_group", args.local_rank) torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend="nccl") print("Initialized init_process_group", args.local_rank) is_master = args.is_master | 
(not args.distributed) distributed_params = {} if args.distributed: distributed_params = {"rank": args.local_rank, "syncbn": True} if args.fp16: distributed_params["apex"] = True distributed_params["opt_level"] = "O1" set_manual_seed(args.seed + args.local_rank) catalyst.utils.set_global_seed(args.seed + args.local_rank) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False data_dir = args.data_dir if data_dir is None: raise ValueError("--data-dir must be set") num_workers = args.workers num_epochs = args.epochs batch_size = args.batch_size learning_rate = args.learning_rate model_name = args.model optimizer_name = args.optimizer image_size = args.size, args.size fast = args.fast augmentations = args.augmentations train_mode = args.train_mode scheduler_name = args.scheduler experiment = args.experiment dropout = args.dropout online_pseudolabeling = args.opl criterions = args.criterion verbose = args.verbose show = args.show use_dsv = args.dsv accumulation_steps = args.accumulation_steps weight_decay = args.weight_decay extra_data_xview2 = args.data_dir_xview2 run_train = num_epochs > 0 need_weight_mask = any(c[0] == "wbce" for c in criterions) custom_model_kwargs = {} if dropout is not None: custom_model_kwargs["dropout"] = float(dropout) model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda() if args.transfer: transfer_checkpoint = fs.auto_file(args.transfer) print("Transfering weights from model checkpoint", transfer_checkpoint) checkpoint = load_checkpoint(transfer_checkpoint) pretrained_dict = checkpoint["model_state_dict"] transfer_weights(model, pretrained_dict) if args.checkpoint: checkpoint = load_checkpoint(fs.auto_file(args.checkpoint)) unpack_checkpoint(checkpoint, model=model) print("Loaded model weights from:", args.checkpoint) report_checkpoint(checkpoint) main_metric = "optimized_jaccard" current_time = datetime.now().strftime("%y%m%d_%H_%M") checkpoint_prefix = f"{current_time}_{args.model}" if fp16: checkpoint_prefix += "_fp16" if fast: checkpoint_prefix += "_fast" if online_pseudolabeling: checkpoint_prefix += "_opl" if extra_data_xview2: checkpoint_prefix += "_with_xview2" if experiment is not None: checkpoint_prefix = experiment if args.distributed: checkpoint_prefix += f"_local_rank_{args.local_rank}" log_dir = os.path.join("runs", checkpoint_prefix) os.makedirs(log_dir, exist_ok=False) config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json") with open(config_fname, "w") as f: train_session_args = vars(args) f.write(json.dumps(train_session_args, indent=2)) default_callbacks = [ PixelAccuracyCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY), # JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"), JaccardMetricPerImageWithOptimalThreshold(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="optimized_jaccard") ] if is_master: default_callbacks += [ BestMetricCheckpointCallback(target_metric="optimized_jaccard", target_metric_minimize=False), HyperParametersCallback( hparam_dict={ "model": model_name, "scheduler": scheduler_name, "optimizer": optimizer_name, "augmentations": augmentations, "size": args.size, "weight_decay": weight_decay, "epochs": num_epochs, "dropout": None if dropout is None else float(dropout) }), ] if show and is_master: visualize_inria_predictions = partial( draw_inria_predictions, image_key=INPUT_IMAGE_KEY, image_id_key=INPUT_IMAGE_ID_KEY, targets_key=INPUT_MASK_KEY, outputs_key=OUTPUT_MASK_KEY, max_images=16, ) default_callbacks += [ 
ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False), ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True), ] train_ds, valid_ds, train_sampler = get_datasets( data_dir=data_dir, image_size=image_size, augmentation=augmentations, train_mode=train_mode, buildings_only=(train_mode == "tiles"), fast=fast, need_weight_mask=need_weight_mask, ) if extra_data_xview2 is not None: extra_train_ds, _ = get_xview2_extra_dataset( extra_data_xview2, image_size=image_size, augmentation=augmentations, fast=fast, need_weight_mask=need_weight_mask, ) weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds)) train_sampler = WeightedRandomSampler(weights, train_sampler.num_samples * 2) train_ds = train_ds + extra_train_ds print("Using extra data from xView2 with", len(extra_train_ds), "samples") if run_train: loaders = collections.OrderedDict() callbacks = default_callbacks.copy() criterions_dict = {} losses = [] ignore_index = None if online_pseudolabeling: ignore_index = UNLABELED_SAMPLE unlabeled_label = get_pseudolabeling_dataset(data_dir, include_masks=False, augmentation=None, image_size=image_size) unlabeled_train = get_pseudolabeling_dataset( data_dir, include_masks=True, augmentation=augmentations, image_size=image_size) loaders["label"] = DataLoader(unlabeled_label, batch_size=batch_size // 2, num_workers=num_workers, pin_memory=True) if train_sampler is not None: num_samples = 2 * train_sampler.num_samples else: num_samples = 2 * len(train_ds) weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label)) train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True) train_ds = train_ds + unlabeled_train callbacks += [ BCEOnlinePseudolabelingCallback2d( unlabeled_train, pseudolabel_loader="label", prob_threshold=0.7, output_key=OUTPUT_MASK_KEY, unlabeled_class=UNLABELED_SAMPLE, label_frequency=5, ) ] print("Using online pseudolabeling with ", len(unlabeled_label), "samples") valid_sampler = None if args.distributed: if train_sampler is not None: train_sampler = DistributedSamplerWrapper(train_sampler, args.world_size, args.local_rank, shuffle=True) else: train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True) valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False) loaders["train"] = DataLoader( train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, shuffle=train_sampler is None, sampler=train_sampler, ) loaders["valid"] = DataLoader(valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler) # Create losses for loss_name, loss_weight in criterions: criterion_callback = CriterionCallback( prefix=f"{INPUT_MASK_KEY}/" + loss_name, input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], output_key=OUTPUT_MASK_KEY, criterion_key=loss_name, multiplier=float(loss_weight), ) criterions_dict[loss_name] = get_loss(loss_name, ignore_index=ignore_index) callbacks.append(criterion_callback) losses.append(criterion_callback.prefix) print("Using loss", loss_name, loss_weight) if use_dsv: print("Using DSV") criterions = "dsv" dsv_loss_name = "soft_bce" criterions_dict[criterions] = AdaptiveMaskLoss2d( get_loss(dsv_loss_name, ignore_index=ignore_index)) for i, dsv_input in enumerate([ OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY, OUTPUT_MASK_32_KEY ]): criterion_callback = 
CriterionCallback( prefix=f"{dsv_input}/" + dsv_loss_name, input_key=INPUT_MASK_KEY, output_key=dsv_input, criterion_key=criterions, multiplier=1.0, ) callbacks.append(criterion_callback) losses.append(criterion_callback.prefix) if isinstance(model, SupervisedHGSegmentationModel): print("Using Hourglass DSV") dsv_loss_name = "kl" criterions_dict["dsv"] = get_loss(dsv_loss_name, ignore_index=ignore_index) num_supervision_inputs = model.encoder.num_blocks - 1 dsv_outputs = [ OUTPUT_MASK_4_KEY + "_after_hg_" + str(i) for i in range(num_supervision_inputs) ] for i, dsv_input in enumerate(dsv_outputs): criterion_callback = CriterionCallback( prefix="supervision/" + dsv_input, input_key=INPUT_MASK_KEY, output_key=dsv_input, criterion_key="dsv", multiplier=(i + 1) / num_supervision_inputs, ) callbacks.append(criterion_callback) losses.append(criterion_callback.prefix) callbacks += [ MetricAggregationCallback(prefix="loss", metrics=losses, mode="sum"), OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False), ] optimizer = get_optimizer(optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay) scheduler = get_scheduler(scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])) if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)): callbacks += [SchedulerCallback(mode="batch")] print("Train session :", checkpoint_prefix) print(" FP16 mode :", fp16) print(" Fast mode :", args.fast) print(" Train mode :", train_mode) print(" Epochs :", num_epochs) print(" Workers :", num_workers) print(" Data dir :", data_dir) print(" Log dir :", log_dir) print(" Augmentations :", augmentations) print(" Train size :", "batches", len(loaders["train"]), "dataset", len(train_ds)) print(" Valid size :", "batches", len(loaders["valid"]), "dataset", len(valid_ds)) print("Model :", model_name) print(" Parameters :", count_parameters(model)) print(" Image size :", image_size) print("Optimizer :", optimizer_name) print(" Learning rate :", learning_rate) print(" Batch size :", batch_size) print(" Criterion :", criterions) print(" Use weight mask:", need_weight_mask) if args.distributed: print("Distributed") print(" World size :", args.world_size) print(" Local rank :", args.local_rank) print(" Is master :", args.is_master) # model training runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda") runner.train( fp16=distributed_params, model=model, criterion=criterions_dict, optimizer=optimizer, scheduler=scheduler, callbacks=callbacks, loaders=loaders, logdir=os.path.join(log_dir, "main"), num_epochs=num_epochs, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": vars(args)}, ) # Training is finished. Let's run predictions using best checkpoint weights best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth") model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth") clean_checkpoint(best_checkpoint, model_checkpoint) unpack_checkpoint(torch.load(model_checkpoint), model=model) mask = predict(model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size) mask = ((mask > 0) * 255).astype(np.uint8) name = os.path.join(log_dir, "sample_color.jpg") cv2.imwrite(name, mask) del optimizer, loaders
runner = SupervisedRunner(input_key=["seg_features"], output_key=["cls_logits", "seg_logits"]) def calc_metric(pred, gt, *args, **kwargs): pred = torch.sigmoid(pred).detach().cpu().numpy() gt = gt.detach().cpu().numpy().astype(np.uint8) try: return [roc_auc_score(gt.reshape(-1), pred.reshape(-1))] except: return [0] callbacks = [ CriterionCallback(input_key="cls_targets", output_key="cls_logits", prefix="loss_cls", criterion_key="cls"), CriterionCallback(input_key="seg_targets", output_key="seg_logits", prefix="loss_seg", criterion_key="seg"), CriterionAggregatorCallback( prefix="loss", loss_keys=["loss_cls", "loss_seg"], loss_aggregate_fn="sum" # or "mean" ), MultiMetricCallback(metric_fn=calc_metric, prefix='rocauc', input_key="cls_targets", output_key="cls_logits", list_args=['_']),
                          input_target_key=None)

optimizer = RAdam(model.parameters(), lr=args.lr, weight_decay=0.001)
scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.75, patience=3, mode="max")
criterion = {'label_loss': nn.CrossEntropyLoss()}

callbacks = [
    CriterionCallback(
        input_key="label",
        output_key="logit_label",
        prefix="label_loss",
        criterion_key="label_loss",
        multiplier=1.0,
    ),
    MetricAggregationCallback(
        prefix="loss",
        metrics=[
            "label_loss",
        ],
    ),
    WeightedAUC(input_key="label", output_key="logit_label"),
]

runner.train(model=model,
             criterion=criterion,
             optimizer=optimizer,
def main(): parser = argparse.ArgumentParser() parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument("--fast", action="store_true") parser.add_argument("-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA sattelite dataset") parser.add_argument("-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset") parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="") parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64") parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run") # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement') # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs') # parser.add_argument('-ft', '--fine-tune', action='store_true') parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate") parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion") parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer") parser.add_argument( "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights") parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers") parser.add_argument("-a", "--augmentations", default="hard", type=str, help="") parser.add_argument("-tm", "--train-mode", default="random", type=str, help="") parser.add_argument("--run-mode", default="fit_predict", type=str, help="") parser.add_argument("--transfer", default=None, type=str, help="") parser.add_argument("--fp16", action="store_true") parser.add_argument("--size", default=512, type=int) parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="") parser.add_argument("-x", "--experiment", default=None, type=str, help="") parser.add_argument("-d", "--dropout", default=0.0, type=float, help="Dropout before head layer") parser.add_argument("--opl", action="store_true") parser.add_argument( "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters") parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay") parser.add_argument("--show", action="store_true") parser.add_argument("--dsv", action="store_true") args = parser.parse_args() set_manual_seed(args.seed) data_dir = args.data_dir num_workers = args.workers num_epochs = args.epochs batch_size = args.batch_size learning_rate = args.learning_rate model_name = args.model optimizer_name = args.optimizer image_size = args.size, args.size fast = args.fast augmentations = args.augmentations train_mode = args.train_mode fp16 = args.fp16 scheduler_name = args.scheduler experiment = args.experiment dropout = args.dropout online_pseudolabeling = args.opl criterions = args.criterion verbose = args.verbose warmup = args.warmup show = args.show use_dsv = args.dsv accumulation_steps = args.accumulation_steps weight_decay = args.weight_decay extra_data_xview2 = args.data_dir_xview2 run_train = num_epochs > 0 need_weight_mask = any(c[0] == "wbce" for c in criterions) model: nn.Module = 
get_model(model_name, dropout=dropout).cuda() if args.transfer: transfer_checkpoint = fs.auto_file(args.transfer) print("Transfering weights from model checkpoint", transfer_checkpoint) checkpoint = load_checkpoint(transfer_checkpoint) pretrained_dict = checkpoint["model_state_dict"] transfer_weights(model, pretrained_dict) if args.checkpoint: checkpoint = load_checkpoint(fs.auto_file(args.checkpoint)) unpack_checkpoint(checkpoint, model=model) print("Loaded model weights from:", args.checkpoint) report_checkpoint(checkpoint) runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda") main_metric = "optimized_jaccard" cmd_args = vars(args) current_time = datetime.now().strftime("%b%d_%H_%M") checkpoint_prefix = f"{current_time}_{args.model}" if fp16: checkpoint_prefix += "_fp16" if fast: checkpoint_prefix += "_fast" if online_pseudolabeling: checkpoint_prefix += "_opl" if extra_data_xview2: checkpoint_prefix += "_with_xview2" if experiment is not None: checkpoint_prefix = experiment log_dir = os.path.join("runs", checkpoint_prefix) os.makedirs(log_dir, exist_ok=False) config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json") with open(config_fname, "w") as f: train_session_args = vars(args) f.write(json.dumps(train_session_args, indent=2)) default_callbacks = [ PixelAccuracyCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY), JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"), OptimalThreshold(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="optimized_jaccard"), # OutputDistributionCallback(output_key=OUTPUT_MASK_KEY, activation=torch.sigmoid), ] if show: visualize_inria_predictions = partial( draw_inria_predictions, image_key=INPUT_IMAGE_KEY, image_id_key=INPUT_IMAGE_ID_KEY, targets_key=INPUT_MASK_KEY, outputs_key=OUTPUT_MASK_KEY, ) default_callbacks += [ ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False) ] train_ds, valid_ds, train_sampler = get_datasets( data_dir=data_dir, image_size=image_size, augmentation=augmentations, train_mode=train_mode, fast=fast, need_weight_mask=need_weight_mask, ) if extra_data_xview2 is not None: extra_train_ds, _ = get_xview2_extra_dataset( extra_data_xview2, image_size=image_size, augmentation=augmentations, fast=fast, need_weight_mask=need_weight_mask, ) weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds)) train_sampler = WeightedRandomSampler(weights, train_sampler.num_samples * 2) train_ds = train_ds + extra_train_ds print("Using extra data from xView2 with", len(extra_train_ds), "samples") # Pretrain/warmup if warmup: callbacks = default_callbacks.copy() criterions_dict = {} losses = [] ignore_index = None for loss_name, loss_weight in criterions: criterion_callback = CriterionCallback( prefix="seg_loss/" + loss_name, input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], output_key=OUTPUT_MASK_KEY, criterion_key=loss_name, multiplier=float(loss_weight), ) criterions_dict[loss_name] = get_loss(loss_name, ignore_index=ignore_index) callbacks.append(criterion_callback) losses.append(criterion_callback.prefix) print("Using loss", loss_name, loss_weight) callbacks += [ CriterionAggregatorCallback(prefix="loss", loss_keys=losses), OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False), ] parameters = get_lr_decay_parameters(model.named_parameters(), learning_rate, {"encoder": 0.1}) optimizer = 
get_optimizer("RAdam", parameters, learning_rate=learning_rate * 0.1) loaders = collections.OrderedDict() loaders["train"] = DataLoader( train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, shuffle=train_sampler is None, sampler=train_sampler, ) loaders["valid"] = DataLoader(valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=False, drop_last=False) runner.train( fp16=fp16, model=model, criterion=criterions_dict, optimizer=optimizer, scheduler=None, callbacks=callbacks, loaders=loaders, logdir=os.path.join(log_dir, "warmup"), num_epochs=warmup, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": cmd_args}, ) del optimizer, loaders best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth") model_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", f"{checkpoint_prefix}_warmup.pth") clean_checkpoint(best_checkpoint, model_checkpoint) torch.cuda.empty_cache() gc.collect() if run_train: loaders = collections.OrderedDict() callbacks = default_callbacks.copy() criterions_dict = {} losses = [] ignore_index = None if online_pseudolabeling: ignore_index = UNLABELED_SAMPLE unlabeled_label = get_pseudolabeling_dataset(data_dir, include_masks=False, augmentation=None, image_size=image_size) unlabeled_train = get_pseudolabeling_dataset( data_dir, include_masks=True, augmentation=augmentations, image_size=image_size) loaders["label"] = DataLoader(unlabeled_label, batch_size=batch_size // 2, num_workers=num_workers, pin_memory=True) if train_sampler is not None: num_samples = 2 * train_sampler.num_samples else: num_samples = 2 * len(train_ds) weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label)) train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True) train_ds = train_ds + unlabeled_train callbacks += [ BCEOnlinePseudolabelingCallback2d( unlabeled_train, pseudolabel_loader="label", prob_threshold=0.7, output_key=OUTPUT_MASK_KEY, unlabeled_class=UNLABELED_SAMPLE, label_frequency=5, ) ] print("Using online pseudolabeling with ", len(unlabeled_label), "samples") loaders["train"] = DataLoader( train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, shuffle=train_sampler is None, sampler=train_sampler, ) loaders["valid"] = DataLoader(valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True) # Create losses for loss_name, loss_weight in criterions: criterion_callback = CriterionCallback( prefix="seg_loss/" + loss_name, input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], output_key=OUTPUT_MASK_KEY, criterion_key=loss_name, multiplier=float(loss_weight), ) criterions_dict[loss_name] = get_loss(loss_name, ignore_index=ignore_index) callbacks.append(criterion_callback) losses.append(criterion_callback.prefix) print("Using loss", loss_name, loss_weight) if use_dsv: print("Using DSV") criterions = "dsv" dsv_loss_name = "soft_bce" criterions_dict[criterions] = AdaptiveMaskLoss2d( get_loss(dsv_loss_name, ignore_index=ignore_index)) for i, dsv_input in enumerate([ OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY, OUTPUT_MASK_32_KEY ]): criterion_callback = CriterionCallback( prefix="seg_loss_dsv/" + dsv_input, input_key=OUTPUT_MASK_KEY, output_key=dsv_input, criterion_key=criterions, multiplier=1.0, ) callbacks.append(criterion_callback) losses.append(criterion_callback.prefix) callbacks += [ 
CriterionAggregatorCallback(prefix="loss", loss_keys=losses), OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False), ] optimizer = get_optimizer(optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay) scheduler = get_scheduler(scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])) if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)): callbacks += [SchedulerCallback(mode="batch")] print("Train session :", checkpoint_prefix) print("\tFP16 mode :", fp16) print("\tFast mode :", args.fast) print("\tTrain mode :", train_mode) print("\tEpochs :", num_epochs) print("\tWorkers :", num_workers) print("\tData dir :", data_dir) print("\tLog dir :", log_dir) print("\tAugmentations :", augmentations) print("\tTrain size :", len(loaders["train"]), len(train_ds)) print("\tValid size :", len(loaders["valid"]), len(valid_ds)) print("Model :", model_name) print("\tParameters :", count_parameters(model)) print("\tImage size :", image_size) print("Optimizer :", optimizer_name) print("\tLearning rate :", learning_rate) print("\tBatch size :", batch_size) print("\tCriterion :", criterions) print("\tUse weight mask:", need_weight_mask) # model training runner.train( fp16=fp16, model=model, criterion=criterions_dict, optimizer=optimizer, scheduler=scheduler, callbacks=callbacks, loaders=loaders, logdir=os.path.join(log_dir, "main"), num_epochs=num_epochs, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": vars(args)}, ) # Training is finished. Let's run predictions using best checkpoint weights best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth") model_checkpoint = os.path.join(log_dir, "main", "checkpoints", f"{checkpoint_prefix}.pth") clean_checkpoint(best_checkpoint, model_checkpoint) unpack_checkpoint(torch.load(model_checkpoint), model=model) mask = predict(model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size) mask = ((mask > 0) * 255).astype(np.uint8) name = os.path.join(log_dir, "sample_color.jpg") cv2.imwrite(name, mask) del optimizer, loaders
def main(cfg: DictConfig): cwd = Path(get_original_cwd()) # overwrite config if continue training from checkpoint resume_cfg = None if "resume" in cfg: cfg_path = cwd / cfg.resume / ".hydra/config.yaml" print(f"Continue from: {cfg.resume}") # Overwrite everything except device # TODO config merger (perhaps continue training with the same optimizer but other lrs?) resume_cfg = OmegaConf.load(cfg_path) cfg.model = resume_cfg.model if cfg.train.num_epochs == 0: cfg.data.scale_factor = resume_cfg.data.scale_factor OmegaConf.save(cfg, ".hydra/config.yaml") print(OmegaConf.to_yaml(cfg)) device = set_device_id(cfg.device) set_seed(cfg.seed, device=device) # Augmentations if cfg.data.aug == "auto": transforms = albu.load(cwd / "autoalbument/autoconfig.json") else: transforms = D.get_training_augmentations() if OmegaConf.is_missing(cfg.model, "convert_bottleneck"): cfg.model.convert_bottleneck = (0, 0, 0) # Model print(f"Setup model {cfg.model.arch} {cfg.model.encoder_name} " f"convert_bn={cfg.model.convert_bn} " f"convert_bottleneck={cfg.model.convert_bottleneck} ") model = get_segmentation_model( arch=cfg.model.arch, encoder_name=cfg.model.encoder_name, encoder_weights=cfg.model.encoder_weights, classes=1, convert_bn=cfg.model.convert_bn, convert_bottleneck=cfg.model.convert_bottleneck, # decoder_attention_type="scse", # TODO to config ) model = model.to(device) model.train() print(model) # Optimization # Reduce LR for pretrained encoder layerwise_params = { "encoder*": dict(lr=cfg.optim.lr_encoder, weight_decay=cfg.optim.wd_encoder) } model_params = cutils.process_model_params( model, layerwise_params=layerwise_params) # Select optimizer optimizer = get_optimizer( name=cfg.optim.name, model_params=model_params, lr=cfg.optim.lr, wd=cfg.optim.wd, lookahead=cfg.optim.lookahead, ) criterion = { "dice": DiceLoss(), # "dice": SoftDiceLoss(mode="binary", smooth=1e-7), "iou": IoULoss(), "bce": nn.BCEWithLogitsLoss(), "lovasz": LovaszLossBinary(), "focal_tversky": FocalTverskyLoss(eps=1e-7, alpha=0.7, gamma=0.75), } # Load states if resuming training if "resume" in cfg: checkpoint_path = (cwd / cfg.resume / cfg.train.logdir / "checkpoints/best_full.pth") if checkpoint_path.exists(): print(f"\nLoading checkpoint {str(checkpoint_path)}") checkpoint = cutils.load_checkpoint(checkpoint_path) cutils.unpack_checkpoint( checkpoint=checkpoint, model=model, optimizer=optimizer if resume_cfg.optim.name == cfg.optim.name else None, criterion=criterion, ) else: raise ValueError("Nothing to resume, checkpoint missing") # We could only want to validate resume, in this case skip training routine best_th = 0.5 stats = None if cfg.data.stats: print(f"Use statistics from file: {cfg.data.stats}") stats = cwd / cfg.data.stats if cfg.train.num_epochs is not None: callbacks = [ # Each criterion is calculated separately. CriterionCallback(input_key="mask", prefix="loss_dice", criterion_key="dice"), CriterionCallback(input_key="mask", prefix="loss_iou", criterion_key="iou"), CriterionCallback(input_key="mask", prefix="loss_bce", criterion_key="bce"), CriterionCallback(input_key="mask", prefix="loss_lovasz", criterion_key="lovasz"), CriterionCallback( input_key="mask", prefix="loss_focal_tversky", criterion_key="focal_tversky", ), # And only then we aggregate everything into one loss. 
MetricAggregationCallback( prefix="loss", mode="weighted_sum", # can be "sum", "weighted_sum" or "mean" # because we want weighted sum, we need to add scale for each loss metrics={ "loss_dice": cfg.loss.dice, "loss_iou": cfg.loss.iou, "loss_bce": cfg.loss.bce, "loss_lovasz": cfg.loss.lovasz, "loss_focal_tversky": cfg.loss.focal_tversky, }, ), # metrics DiceCallback(input_key="mask"), IouCallback(input_key="mask"), # gradient accumulation OptimizerCallback(accumulation_steps=cfg.optim.accumulate), # early stopping SchedulerCallback(reduced_metric="loss_dice", mode=cfg.scheduler.mode), EarlyStoppingCallback(**cfg.scheduler.early_stopping, minimize=False), # TODO WandbLogger works poorly with multistage right now WandbLogger(project=cfg.project, config=dict(cfg)), # CheckpointCallback(save_n_best=cfg.checkpoint.save_n_best), ] # Training runner = SupervisedRunner(device=device, input_key="image", input_target_key="mask") # TODO Scheduler does not work now, every stage restarts from base lr scheduler_warm_restart = optim.lr_scheduler.MultiStepLR( optimizer, milestones=[1, 2], gamma=10, ) for i, (size, num_epochs) in enumerate( zip(cfg.data.sizes, cfg.train.num_epochs)): scale = size / 1024 print( f"Training stage {i}, scale {scale}, size {size}, epochs {num_epochs}" ) # Datasets ( train_ds, valid_ds, train_images, val_images, ) = D.get_train_valid_datasets_from_path( # path=(cwd / cfg.data.path), path=(cwd / f"data/hubmap-{size}x{size}/"), train_ids=cfg.data.train_ids, valid_ids=cfg.data.valid_ids, seed=cfg.seed, valid_split=cfg.data.valid_split, mean=cfg.data.mean, std=cfg.data.std, transforms=transforms, stats=stats, ) train_bs = int(cfg.loader.train_bs / (scale**2)) valid_bs = int(cfg.loader.valid_bs / (scale**2)) print( f"train: {len(train_ds)}; bs {train_bs}", f"valid: {len(valid_ds)}, bs {valid_bs}", ) # Data loaders data_loaders = D.get_data_loaders( train_ds=train_ds, valid_ds=valid_ds, train_bs=train_bs, valid_bs=valid_bs, num_workers=cfg.loader.num_workers, ) # Select scheduler scheduler = get_scheduler( name=cfg.scheduler.type, optimizer=optimizer, num_epochs=num_epochs * (len(data_loaders["train"]) if cfg.scheduler.mode == "batch" else 1), eta_min=scheduler_warm_restart.get_last_lr()[0] / cfg.scheduler.eta_min_factor, plateau=cfg.scheduler.plateau, ) runner.train( model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, callbacks=callbacks, logdir=cfg.train.logdir, loaders=data_loaders, num_epochs=num_epochs, verbose=True, main_metric=cfg.train.main_metric, load_best_on_end=True, minimize_metric=False, check=cfg.check, fp16=dict(amp=cfg.amp), ) # Set new initial LR for optimizer after restart scheduler_warm_restart.step() print( f"New LR for warm restart {scheduler_warm_restart.get_last_lr()[0]}" ) # Find optimal threshold for dice score model.eval() best_th, dices = find_dice_threshold(model, data_loaders["valid"]) print("Best dice threshold", best_th, np.max(dices[1])) np.save(f"dices_{size}.npy", dices) else: print("Validation only") # Datasets size = cfg.data.sizes[-1] train_ds, valid_ds = D.get_train_valid_datasets_from_path( # path=(cwd / cfg.data.path), path=(cwd / f"data/hubmap-{size}x{size}/"), train_ids=cfg.data.train_ids, valid_ids=cfg.data.valid_ids, seed=cfg.seed, valid_split=cfg.data.valid_split, mean=cfg.data.mean, std=cfg.data.std, transforms=transforms, stats=stats, ) train_bs = int(cfg.loader.train_bs / (cfg.data.scale_factor**2)) valid_bs = int(cfg.loader.valid_bs / (cfg.data.scale_factor**2)) print( f"train: {len(train_ds)}; bs {train_bs}", 
f"valid: {len(valid_ds)}, bs {valid_bs}", ) # Data loaders data_loaders = D.get_data_loaders( train_ds=train_ds, valid_ds=valid_ds, train_bs=train_bs, valid_bs=valid_bs, num_workers=cfg.loader.num_workers, ) # Find optimal threshold for dice score model.eval() best_th, dices = find_dice_threshold(model, data_loaders["valid"]) print("Best dice threshold", best_th, np.max(dices[1])) np.save(f"dices_val.npy", dices) # # # Load best checkpoint # checkpoint_path = Path(cfg.train.logdir) / "checkpoints/best.pth" # if checkpoint_path.exists(): # print(f"\nLoading checkpoint {str(checkpoint_path)}") # state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[ # "model_state_dict" # ] # model.load_state_dict(state_dict) # del state_dict # model = model.to(device) # Load config for updating with threshold and metric # (otherwise loading do not work) cfg = OmegaConf.load(".hydra/config.yaml") cfg.threshold = float(best_th) # Evaluate on full-size image if valid_ids is non-empty df_train = pd.read_csv(cwd / "data/train.csv") df_train = { r["id"]: r["encoding"] for r in df_train.to_dict(orient="record") } dices = [] unique_ids = sorted( set( str(p).split("/")[-1].split("_")[0] for p in (cwd / cfg.data.path / "train").iterdir())) size = cfg.data.sizes[-1] scale = size / 1024 for image_id in cfg.data.valid_ids: image_name = unique_ids[image_id] print(f"\nValidate for {image_name}") rle_pred, shape = inference_one( image_path=(cwd / f"data/train/{image_name}.tiff"), target_path=Path("."), cfg=cfg, model=model, scale_factor=scale, tile_size=cfg.data.tile_size, tile_step=cfg.data.tile_step, threshold=best_th, save_raw=False, tta_mode=None, weight="pyramid", device=device, filter_crops="tissue", stats=stats, ) print("Predict", shape) pred = rle_decode(rle_pred["predicted"], shape) mask = rle_decode(df_train[image_name], shape) assert pred.shape == mask.shape, f"pred {pred.shape}, mask {mask.shape}" assert pred.shape == shape, f"pred {pred.shape}, expected {shape}" dices.append( dice( torch.from_numpy(pred).type(torch.uint8), torch.from_numpy(mask).type(torch.uint8), threshold=None, activation="none", )) print("Full image dice:", np.mean(dices)) OmegaConf.save(cfg, ".hydra/config.yaml") return
def train_segmentation_model(
    model: torch.nn.Module,
    logdir: str,
    num_epochs: int,
    loaders: Dict[str, DataLoader],
):
    criterion = {
        "dice": DiceLoss(),
        "iou": IoULoss(),
        "bce": nn.BCEWithLogitsLoss(),
    }

    learning_rate = 0.001
    encoder_learning_rate = 0.0005
    layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
    model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
    base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
    optimizer = Lookahead(base_optimizer)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=2)

    device = utils.get_device()
    runner = SupervisedRunner(device=device, input_key='image', input_target_key='mask')

    callbacks = [
        CriterionCallback(input_key="mask", prefix="loss_dice", criterion_key="dice"),
        CriterionCallback(input_key="mask", prefix="loss_iou", criterion_key="iou"),
        CriterionCallback(input_key="mask", prefix="loss_bce", criterion_key="bce"),
        MetricAggregationCallback(
            prefix="loss",
            mode="weighted_sum",
            metrics={"loss_dice": 1.0, "loss_iou": 1.0, "loss_bce": 0.8},
        ),
        # metrics
        DiceCallback(input_key='mask'),
        IouCallback(input_key='mask'),
    ]

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=callbacks,
        logdir=logdir,
        num_epochs=num_epochs,
        main_metric="iou",
        minimize_metric=False,
        verbose=True,
        load_best_on_end=True,
    )

    best_model_save_dir = os.path.join(logdir, 'save')
    os.makedirs(best_model_save_dir, exist_ok=True)
    torch.save(model, os.path.join(best_model_save_dir, 'best_model.pth'))  # save best model (by valid loss)

    batch = next(iter(loaders["valid"]))
    try:
        runner.trace(model=model, batch=batch, logdir=logdir, fp16=False)  # optimized version (not all models can be traced)
    except Exception:
        pass
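# Hedged usage sketch: the loader dict is expected to provide "train" and "valid"
# DataLoaders whose batches expose the 'image' and 'mask' keys used above.
# `train_ds` and `valid_ds` are placeholders for the project's datasets.
loaders = collections.OrderedDict(
    train=DataLoader(train_ds, batch_size=8, shuffle=True, drop_last=True),
    valid=DataLoader(valid_ds, batch_size=8, shuffle=False),
)
train_segmentation_model(model, logdir="runs/segmentation", num_epochs=40, loaders=loaders)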
"logit_grapheme", ), input_key=( "grapheme_root", "vowel_diacritic", "consonant_diacritic", "grapheme", ), mixuponly=False, alpha=0.5, resolution=IMG_SIZE, ), CriterionCallback( input_key="grapheme_root", output_key="logit_grapheme_root", prefix="grapheme_root_loss", criterion_key="grapheme_root_loss", multiplier=2.0, ), CriterionCallback( input_key="vowel_diacritic", output_key="logit_vowel_diacritic", prefix="vowel_diacritic_loss", criterion_key="vowel_diacritic_loss", multiplier=1.0, ), CriterionCallback( input_key="consonant_diacritic", output_key="logit_consonant_diacritic", prefix="consonant_diacritic_loss", criterion_key="consonant_diacritic_loss",
def get_cls_callbacks(loss_name, num_classes, num_epochs, class_names, tsa=None, uda=None, show=False):
    if len(loss_name) == 1:
        loss_name, multiplier = loss_name[0], 1.0
    elif len(loss_name) == 2:
        loss_name, multiplier = loss_name[0], float(loss_name[1])
    else:
        raise ValueError(loss_name)

    criterions = {'cls': get_loss(loss_name, ignore_index=UNLABELED_CLASS)}
    output_key = 'logits'

    if tsa:
        crit_callback = TSACriterionCallback(prefix='cls/tsa_loss',
                                             loss_key='cls',
                                             output_key=output_key,
                                             criterion_key='cls',
                                             multiplier=multiplier,
                                             num_classes=num_classes,
                                             num_epochs=num_epochs)
    else:
        crit_callback = CriterionCallback(prefix='cls/loss',
                                          loss_key='cls',
                                          output_key=output_key,
                                          criterion_key='cls',
                                          multiplier=multiplier)

    callbacks = [
        crit_callback,
        CappaScoreCallback(prefix='cls/kappa',
                           output_key=output_key,
                           ignore_index=UNLABELED_CLASS,
                           class_names=class_names),
        # Metrics
        CustomAccuracyCallback(output_key=output_key, prefix='cls/accuracy', ignore_index=UNLABELED_CLASS),
        # F1 scores
        FScoreCallback(prefix='cls/f1_macro', beta=1, average='macro',
                       output_key=output_key, ignore_index=UNLABELED_CLASS),
        FScoreCallback(prefix='cls/f1_micro', beta=1, average='micro',
                       output_key=output_key, ignore_index=UNLABELED_CLASS),
        # F2 scores
        FScoreCallback(prefix='cls/f2_macro', beta=2, average='macro',
                       output_key=output_key, ignore_index=UNLABELED_CLASS),
        FScoreCallback(prefix='cls/f2_micro', beta=2, average='micro',
                       output_key=output_key, ignore_index=UNLABELED_CLASS),
    ]

    if uda:
        callbacks += [
            UDACriterionCallback(prefix='cls/uda', output_key=output_key, unsupervised_label=UNLABELED_CLASS)
        ]
    else:
        callbacks += [
            ConfusionMatrixCallback(prefix='cls/confusion', output_key=output_key, class_names=class_names),
            NegativeMiningCallback(ignore_index=UNLABELED_CLASS),
        ]

    if show:
        visualization_fn = partial(draw_classification_predictions, class_names=class_names)
        callbacks += [
            ShowPolarBatchesCallback(visualization_fn, metric='cls/accuracy', minimize=False)
        ]

    return callbacks, criterions
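# Hedged usage sketch: the returned pair plugs into runner.train() the same way as the
# other snippets in this file; the loss spec ['ce'], the class names, and the runner
# arguments are illustrative placeholders.
cls_callbacks, cls_criterions = get_cls_callbacks(
    loss_name=['ce'],
    num_classes=5,
    num_epochs=num_epochs,
    class_names=['0', '1', '2', '3', '4'],
)
runner.train(model=model, criterion=cls_criterions, optimizer=optimizer,
             loaders=loaders, callbacks=cls_callbacks, num_epochs=num_epochs)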
def get_reg_callbacks(loss_name, class_names, prefix='reg', output_key='regression', uda=None, show=False):
    if len(loss_name) == 1:
        loss_name, multiplier = loss_name[0], 1.0
    elif len(loss_name) == 2:
        loss_name, multiplier = loss_name[0], float(loss_name[1])
    else:
        raise ValueError(loss_name)

    criterions = {prefix: get_loss(loss_name, ignore_index=UNLABELED_CLASS)}

    callbacks = [
        # Loss
        CriterionCallback(prefix=f'{prefix}/loss',
                          loss_key=prefix,
                          output_key=output_key,
                          criterion_key=prefix,
                          multiplier=multiplier),
        # Metrics
        RMSEMetric(prefix=f'{prefix}/rmse', output_key=output_key),
        CappaScoreCallback(prefix=f'{prefix}/kappa',
                           output_key=output_key,
                           ignore_index=UNLABELED_CLASS,
                           class_names=class_names,
                           optimize_thresholds=False,
                           from_regression=True),
        CustomAccuracyCallback(prefix=f'{prefix}/accuracy',
                               output_key=output_key,
                               from_regression=True,
                               ignore_index=UNLABELED_CLASS),
        ConfusionMatrixCallbackFromRegression(prefix=f'{prefix}/confusion',
                                              output_key=output_key,
                                              class_names=class_names,
                                              ignore_index=UNLABELED_CLASS),
        # F1 scores
        FScoreCallback(prefix=f'{prefix}/f1_macro', beta=1, average='macro',
                       output_key=output_key, from_regression=True, ignore_index=UNLABELED_CLASS),
        FScoreCallback(prefix=f'{prefix}/f1_micro', beta=1, average='micro',
                       output_key=output_key, from_regression=True, ignore_index=UNLABELED_CLASS),
        # F2 scores
        FScoreCallback(prefix=f'{prefix}/f2_macro', beta=2, average='macro',
                       output_key=output_key, from_regression=True, ignore_index=UNLABELED_CLASS),
        FScoreCallback(prefix=f'{prefix}/f2_micro', beta=2, average='micro',
                       output_key=output_key, from_regression=True, ignore_index=UNLABELED_CLASS),
    ]

    if uda:
        callbacks += [
            UDARegressionCriterionCallback(prefix=f'{prefix}/uda',
                                           output_key=output_key,
                                           unsupervised_label=UNLABELED_CLASS)
        ]

    if show:
        visualization_fn = partial(draw_regression_predictions,
                                   outputs_key=output_key,
                                   class_names=class_names,
                                   unsupervised_label=UNLABELED_CLASS)
        callbacks += [
            ShowPolarBatchesCallback(visualization_fn, metric=f'{prefix}/accuracy', minimize=False)
        ]

    return callbacks, criterions