def log_images(
    cfg: Config, image_batch, name, step, nsamples=64, nrows=8, monochrome=False, prefix=None
):
    """Make a grid of the given images and log them with W&B."""
    prefix = "train_" if prefix is None else f"{prefix}_"
    images = image_batch[:nsamples]

    if cfg.enc.recon_loss == RL.ce:
        # with a cross-entropy reconstruction loss, the channel dimension holds
        # logits over the 256 pixel values; recover the pixel intensities
        images = images.argmax(dim=1).float() / 255
    else:
        if cfg.data.dataset in (DS.celeba, DS.genfaces):
            images = 0.5 * images + 0.5  # undo the [-1, 1] normalization

    if monochrome:
        images = images.mean(dim=1, keepdim=True)
    # torchvision.utils.save_image(images, f"./experiments/finn/{prefix}{name}.png", nrow=nrows)
    shw = torchvision.utils.make_grid(images, nrow=nrows).clamp(0, 1).cpu()
    wandb_log(
        cfg.misc,
        {prefix + name: [wandb.Image(torchvision.transforms.functional.to_pil_image(shw))]},
        step=step,
    )
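
# A minimal, self-contained sketch of the grid-building step used above, with
# random data standing in for a real batch (W&B logging omitted; the function
# name is illustrative, not part of the project):
def _example_make_grid() -> None:
    import torch
    import torchvision
    import torchvision.transforms.functional as TF

    fake_batch = torch.rand(64, 3, 32, 32)  # 64 RGB images with values in [0, 1]
    grid = torchvision.utils.make_grid(fake_batch[:16], nrow=8).clamp(0, 1)
    TF.to_pil_image(grid).save("grid_preview.png")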
def fit(self, train_data: DataLoader, epochs: int, device: torch.device, use_wandb: bool) -> None:
    self.train()
    step = 0
    logging_dict = {}
    # enc_sched = torch.optim.lr_scheduler.StepLR(self.encoder.optimizer, step_size=9, gamma=0.3)
    # dec_sched = torch.optim.lr_scheduler.StepLR(self.decoder.optimizer, step_size=9, gamma=0.3)
    with tqdm(total=epochs * len(train_data)) as pbar:
        for _ in range(epochs):
            for x, _, _ in train_data:
                x = x.to(device)

                self.zero_grad()
                _, loss, logging_dict = self.routine(x)
                loss.backward()
                self.step()

                enc_loss: float = loss.item()
                pbar.update()
                pbar.set_postfix(AE_loss=enc_loss)
                if use_wandb:
                    step += 1
                    logging_dict.update({"Total Loss": enc_loss})
                    wandb_log(True, logging_dict, step)
            # enc_sched.step()
            # dec_sched.step()
    log.info("Final result from encoder training:")
    print_metrics({f"Enc {k}": v for k, v in logging_dict.items()})
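
# Usage sketch (hypothetical wiring): `ae` is an instance of the autoencoder
# class this method belongs to, and `loader` yields (x, s, y) triples, e.g.
# the `enc_train_loader` built in `main` below.
def _example_fit_autoencoder(ae, loader: DataLoader) -> None:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ae.to(device)
    ae.fit(loader, epochs=30, device=device, use_wandb=False)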
def train_step(
    components: Union["AeComponents", InnComponents],
    context_data_itr: Iterator[Tuple[Tensor, Tensor, Tensor]],
    train_data_itr: Iterator[Tuple[Tensor, Tensor, Tensor]],
    itr: int,
) -> Dict[str, float]:
    disc_weight = 0.0 if itr < ARGS.warmup_steps else ARGS.disc_weight
    # initialize here so the update below cannot hit an unbound name when the
    # NN-discriminator branch is skipped
    disc_logging: Dict[str, float] = {}
    if ARGS.disc_method == DM.nn:
        # Train the discriminator on its own for a number of iterations
        for _ in range(ARGS.num_disc_updates):
            x_c, x_t, s_t, y_t = get_batch(
                context_data_itr=context_data_itr, train_data_itr=train_data_itr
            )
            if components.type_ == "ae":
                _, disc_logging = update_disc(x_c, x_t, components, itr < ARGS.warmup_steps)
            else:
                update_disc_on_inn(ARGS, x_c, x_t, components, itr < ARGS.warmup_steps)

    x_c, x_t, s_t, y_t = get_batch(context_data_itr=context_data_itr, train_data_itr=train_data_itr)
    if components.type_ == "ae":
        _, logging_dict = update(
            x_c=x_c, x_t=x_t, s_t=s_t, y_t=y_t, ae=components, warmup=itr < ARGS.warmup_steps
        )
    else:
        _, logging_dict = update_inn(
            args=ARGS, x_c=x_c, x_t=x_t, models=components, disc_weight=disc_weight
        )

    logging_dict.update(disc_logging)
    wandb_log(MISC, logging_dict, step=itr)

    # Log images
    if itr % ARGS.log_freq == 0:
        with torch.no_grad():
            if components.type_ == "ae":
                generator = components.generator
                x_log = x_t
            else:
                generator = components.inn
                x_log = x_c
            log_recons(generator=generator, x=x_log, itr=itr)
    return logging_dict
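
# `train_step` expects iterators that never run dry. A sketch of how such
# iterators can be built from finite DataLoaders (the project may use its own
# helper for this; `inf_generator` is an assumed name):
def inf_generator(loader: DataLoader) -> Iterator:
    """Yield batches forever, restarting the loader whenever it is exhausted."""
    while True:
        yield from loader


def _example_training_loop(
    components, context_loader: DataLoader, train_loader: DataLoader, total_steps: int
) -> None:
    context_data_itr = inf_generator(context_loader)
    train_data_itr = inf_generator(train_loader)
    for itr in range(total_steps):
        train_step(components, context_data_itr, train_data_itr, itr)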
def train(
    cfg: Config,
    encoder: Encoder,
    context_data: Dataset,
    num_clusters: int,
    s_count: int,
    enc_path: Path,
) -> ClusterResults:
    # encode the context set with the encoder
    encoded = encode_dataset(cfg, context_data, encoder)

    # create data loader with one giant batch
    data_loader = DataLoader(encoded, batch_size=len(encoded), shuffle=False)
    encoded, s, y = next(iter(data_loader))
    preds = run_kmeans_faiss(
        encoded,
        nmb_clusters=num_clusters,
        cuda=str(cfg.misc._device) != "cpu",
        n_iter=cfg.clust.epochs,
        verbose=True,
    )
    cluster_ids = preds.cpu().numpy()
    # preds, _ = run_kmeans_torch(encoded, num_clusters, device=args._device, n_iter=args.epochs, verbose=True)
    counts = np.zeros((num_clusters, num_clusters), dtype=np.int64)
    counts, class_ids = count_occurances(counts, cluster_ids, s, y, s_count, cfg.clust.cluster)

    _, context_metrics, logging_dict = cluster_metrics(
        cluster_ids=cluster_ids,
        counts=counts,
        true_class_ids=class_ids.numpy(),
        num_total=preds.size(0),
        s_count=s_count,
        to_cluster=cfg.clust.cluster,
    )
    prepared = (
        f"{k}: {v:.5g}" if isinstance(v, float) else f"{k}: {v}" for k, v in logging_dict.items()
    )
    log.info(" | ".join(prepared))
    wandb_log(cfg.misc, logging_dict, step=0)
    log.info("Context metrics:")
    print_metrics({f"Context {k}": v for k, v in context_metrics.items()})
    return ClusterResults(
        flags=flatten(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)),
        cluster_ids=preds,
        class_ids=get_class_id(s=s, y=y, s_count=s_count, to_cluster=cfg.clust.cluster),
        enc_path=enc_path,
        context_metrics=context_metrics,
    )
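
# For reference, a minimal sketch of what the k-means call boils down to when
# using the faiss library directly (`run_kmeans_faiss` is the project's own
# wrapper; the options it actually passes may differ):
def _example_kmeans_faiss(encoded: Tensor, num_clusters: int) -> np.ndarray:
    import faiss

    x = encoded.detach().cpu().numpy().astype(np.float32)  # (n_samples, dim)
    kmeans = faiss.Kmeans(x.shape[1], num_clusters, niter=100, verbose=True)
    kmeans.train(x)
    _, cluster_ids = kmeans.index.search(x, 1)  # index of nearest centroid per sample
    return cluster_ids.squeeze(1)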
def fit(
    self,
    train_data: Union[Dataset, DataLoader],
    epochs: int,
    device: torch.device,
    use_wandb: bool,
    test_data: Optional[Union[Dataset, DataLoader]] = None,
    pred_s: bool = False,
    batch_size: int = 256,
    test_batch_size: int = 1000,
    lr_milestones: Optional[Dict] = None,
) -> None:
    if not isinstance(train_data, DataLoader):
        train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True)
    if test_data is not None:
        if not isinstance(test_data, DataLoader):
            test_data = DataLoader(
                test_data, batch_size=test_batch_size, shuffle=False, pin_memory=True
            )
    scheduler = None
    if lr_milestones is not None:
        scheduler = MultiStepLR(optimizer=self.optimizer, **lr_milestones)

    log.info("Training classifier...")
    pbar = trange(epochs)
    for epoch in pbar:
        self.model.train()

        for step, (x, *target) in enumerate(train_data, start=epoch * len(train_data)):
            if len(target) == 2:
                target = target[0] if pred_s else target[1]
            else:
                target = target[0]

            x = x.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            self.optimizer.zero_grad()
            loss, acc = self.routine(x, target)
            loss.backward()
            self.optimizer.step()

            wandb_log(use_wandb, {"loss": loss.item()}, step=step)
            pbar.set_postfix(epoch=epoch + 1, train_loss=loss.item(), train_acc=acc)

        if test_data is not None:
            self.model.eval()
            avg_test_acc = 0.0

            with torch.set_grad_enabled(False):
                for x, s, y in test_data:
                    target = s if pred_s else y
                    x = x.to(device)
                    target = target.to(device)

                    loss, acc = self.routine(x, target)
                    avg_test_acc += acc

            avg_test_acc /= len(test_data)
            pbar.set_postfix(epoch=epoch + 1, avg_test_acc=avg_test_acc)
        else:
            pbar.set_postfix(epoch=epoch + 1)

        if scheduler is not None:
            scheduler.step(epoch)
    pbar.close()
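
# Because `lr_milestones` is unpacked directly into MultiStepLR, it should be a
# dict of that scheduler's keyword arguments. A usage sketch (hypothetical
# values; `clf` is a classifier instance, `train_set` a Dataset):
def _example_fit_classifier(clf, train_set: Dataset, device: torch.device) -> None:
    clf.fit(
        train_set,
        epochs=90,
        device=device,
        use_wandb=False,
        lr_milestones={"milestones": [30, 60, 80], "gamma": 0.1},
    )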
def main(cfg: Config, cluster_label_file: Optional[Path] = None) -> Tuple[Model, Path]:
    """Main function.

    Args:
        cluster_label_file: path to a pth file with cluster IDs; if given,
            this argument overwrites the corresponding config value

    Returns:
        the trained model and the path to the saved cluster results
    """
    # ==== initialize globals ====
    global ARGS, CFG, DATA, ENC, MISC
    ARGS = cfg.clust
    CFG = cfg
    DATA = cfg.data
    ENC = cfg.enc
    MISC = cfg.misc

    # ==== current git commit ====
    if os.environ.get("STARTED_BY_GUILDAI", None) == "1":
        sha = ""
    else:
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha

    use_gpu = torch.cuda.is_available() and MISC.gpu >= 0
    random_seed(MISC.seed, use_gpu)
    if cluster_label_file is not None:
        MISC.cluster_label_file = str(cluster_label_file)

    run = None
    if MISC.use_wandb:
        group = ""
        if MISC.log_method:
            group += MISC.log_method
        if MISC.exp_group:
            group += "." + MISC.exp_group
        if cfg.bias.log_dataset:
            group += "." + cfg.bias.log_dataset
        run = wandb.init(
            entity="anonymous",
            project="fcm-hydra",
            config=flatten(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)),
            group=group if group else None,
            reinit=True,
        )

    save_dir = Path(to_absolute_path(MISC.save_dir)) / str(time.time())
    save_dir.mkdir(parents=True, exist_ok=True)

    log.info(str(OmegaConf.to_yaml(cfg, resolve=True, sort_keys=True)))
    log.info(f"Save directory: {save_dir.resolve()}")

    # ==== check GPU ====
    MISC._device = f"cuda:{MISC.gpu}" if use_gpu else "cpu"
    device = torch.device(MISC._device)
    log.info(f"{torch.cuda.device_count()} GPUs available. Using device '{device}'")

    # ==== construct dataset ====
    datasets: DatasetTriplet = load_dataset(CFG)
    log.info(
        "Size of context-set: {}, training-set: {}, test-set: {}".format(
            len(datasets.context),
            len(datasets.train),
            len(datasets.test),
        )
    )
    ARGS.test_batch_size = ARGS.test_batch_size if ARGS.test_batch_size else ARGS.batch_size
    context_batch_size = round(ARGS.batch_size * len(datasets.context) / len(datasets.train))
    context_loader = DataLoader(
        datasets.context,
        shuffle=True,
        batch_size=context_batch_size,
        num_workers=MISC.num_workers,
        pin_memory=True,
    )
    enc_train_data = ConcatDataset([datasets.context, datasets.train])
    if ARGS.encoder == Enc.rotnet:
        enc_train_loader = DataLoader(
            RotationPrediction(enc_train_data, apply_all=True),
            shuffle=True,
            batch_size=ARGS.batch_size,
            num_workers=MISC.num_workers,
            pin_memory=True,
            collate_fn=adaptive_collate,
        )
    else:
        enc_train_loader = DataLoader(
            enc_train_data,
            shuffle=True,
            batch_size=ARGS.batch_size,
            num_workers=MISC.num_workers,
            pin_memory=True,
        )
    train_loader = DataLoader(
        datasets.train,
        shuffle=True,
        batch_size=ARGS.batch_size,
        num_workers=MISC.num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        datasets.test,
        shuffle=False,
        batch_size=ARGS.test_batch_size,
        num_workers=MISC.num_workers,
        pin_memory=True,
    )

    # ==== construct networks ====
    input_shape = get_data_dim(context_loader)
    s_count = datasets.s_dim if datasets.s_dim > 1 else 2
    y_count = datasets.y_dim if datasets.y_dim > 1 else 2
    if ARGS.cluster == CL.s:
        num_clusters = s_count
    elif ARGS.cluster == CL.y:
        num_clusters = y_count
    else:
        num_clusters = s_count * y_count
    log.info(
        f"Number of clusters: {num_clusters}, accuracy computed with respect to {ARGS.cluster.name}"
    )
    mappings: List[str] = []
    for i in range(num_clusters):
        if ARGS.cluster == CL.s:
            mappings.append(f"{i}: s = {i}")
        elif ARGS.cluster == CL.y:
            mappings.append(f"{i}: y = {i}")
        else:
            # class_id = y * s_count + s
            mappings.append(f"{i}: (y = {i // s_count}, s = {i % s_count})")
    log.info("class IDs:\n\t" + "\n\t".join(mappings))
    feature_group_slices = getattr(datasets.context, "feature_group_slices", None)

    # ================================= encoder =================================
    encoder: Encoder
    enc_shape: Tuple[int, ...]
    if ARGS.encoder in (Enc.ae, Enc.vae):
        encoder, enc_shape = build_ae(CFG, input_shape, feature_group_slices)
    else:
        if len(input_shape) < 2:
            raise ValueError("RotNet can only be applied to image data.")
        enc_optimizer_kwargs = {"lr": ARGS.enc_lr, "weight_decay": ARGS.enc_wd}
        enc_kwargs = {"pretrained": False, "num_classes": 4, "zero_init_residual": True}
        net = resnet18(**enc_kwargs) if DATA.dataset == DS.cmnist else resnet50(**enc_kwargs)

        encoder = SelfSupervised(model=net, num_classes=4, optimizer_kwargs=enc_optimizer_kwargs)
        enc_shape = (512,)
    encoder.to(device)

    log.info(f"Encoding shape: {enc_shape}")

    enc_path: Path
    if ARGS.enc_path:
        enc_path = Path(ARGS.enc_path)
        if ARGS.encoder == Enc.rotnet:
            assert isinstance(encoder, SelfSupervised)
            encoder = encoder.get_encoder()
        save_dict = torch.load(ARGS.enc_path, map_location=lambda storage, loc: storage)
        encoder.load_state_dict(save_dict["encoder"])
        if "args" in save_dict:
            args_encoder = save_dict["args"]
            assert ARGS.encoder.name == args_encoder["encoder_type"]
            assert ENC.levels == args_encoder["levels"]
    else:
        encoder.fit(
            enc_train_loader, epochs=ARGS.enc_epochs, device=device, use_wandb=ARGS.enc_wandb
        )
        if ARGS.encoder == Enc.rotnet:
            assert isinstance(encoder, SelfSupervised)
            encoder = encoder.get_encoder()
        # the args names follow the convention of the standalone VAE commandline args
        args_encoder = {"encoder_type": ARGS.encoder.name, "levels": ENC.levels}
        enc_path = save_dir.resolve() / "encoder"
        torch.save({"encoder": encoder.state_dict(), "args": args_encoder}, enc_path)
        log.info(f"To make use of this encoder:\n--enc-path {enc_path}")
        if ARGS.enc_wandb:
            log.info("Stopping here because W&B will be messed up...")
            if run is not None:
                run.finish()  # this allows multiple experiments in one python process
            return

    cluster_label_path = get_cluster_label_path(MISC, save_dir)
    if ARGS.method == Meth.kmeans:
        kmeans_results = train_k_means(
            CFG, encoder, datasets.context, num_clusters, s_count, enc_path
        )
        pth = save_results(save_path=cluster_label_path, cluster_results=kmeans_results)
        if run is not None:
            run.finish()  # this allows multiple experiments in one python process
        return (), pth
    if ARGS.finetune_encoder:
        encoder.freeze_initial_layers(
            ARGS.freeze_layers, {"lr": ARGS.finetune_lr, "weight_decay": ARGS.weight_decay}
        )

    # ================================= labeler =================================
    pseudo_labeler: PseudoLabeler
    if ARGS.pseudo_labeler == PL.ranking:
        pseudo_labeler = RankingStatistics(k_num=ARGS.k_num)
    elif ARGS.pseudo_labeler == PL.cosine:
        pseudo_labeler = CosineSimThreshold(
            upper_threshold=ARGS.upper_threshold, lower_threshold=ARGS.lower_threshold
        )

    # ================================= method =================================
    method: Method
    if ARGS.method == Meth.pl_enc:
        method = PseudoLabelEnc()
    elif ARGS.method == Meth.pl_output:
        method = PseudoLabelOutput()
    elif ARGS.method == Meth.pl_enc_no_norm:
        method = PseudoLabelEncNoNorm()

    # ================================= classifier =================================
    clf_optimizer_kwargs = {"lr": ARGS.lr, "weight_decay": ARGS.weight_decay}
    clf_fn = FcNet(hidden_dims=ARGS.cl_hidden_dims)
    clf_input_shape = (prod(enc_shape),)  # FcNet first flattens the input

    classifier = build_classifier(
        input_shape=clf_input_shape,
        target_dim=s_count if ARGS.use_multi_head else num_clusters,
        model_fn=clf_fn,
        optimizer_kwargs=clf_optimizer_kwargs,
        num_heads=y_count if ARGS.use_multi_head else 1,
    )
    classifier.to(device)

    model: Union[Model, MultiHeadModel]
    if ARGS.use_multi_head:
        labeler_fn: ModelFn
        if DATA.dataset == DS.cmnist:
            labeler_fn = Mp32x23Net(batch_norm=True)
        elif DATA.dataset == DS.celeba:
            labeler_fn = Mp64x64Net(batch_norm=True)
        else:
            labeler_fn = FcNet(hidden_dims=ARGS.labeler_hidden_dims)

        labeler_optimizer_kwargs = {"lr": ARGS.labeler_lr, "weight_decay": ARGS.labeler_wd}
        labeler: Classifier = build_classifier(
            input_shape=input_shape,
            target_dim=s_count,
            model_fn=labeler_fn,
            optimizer_kwargs=labeler_optimizer_kwargs,
        )
        labeler.to(device)
        log.info("Fitting the labeler to the labeled data.")
        labeler.fit(
            train_loader,
            epochs=ARGS.labeler_epochs,
            device=device,
            use_wandb=ARGS.labeler_wandb,
        )
        labeler.eval()
        model = MultiHeadModel(
            encoder=encoder,
            classifiers=classifier,
            method=method,
            pseudo_labeler=pseudo_labeler,
            labeler=labeler,
            train_encoder=ARGS.finetune_encoder,
        )
    else:
        model = Model(
            encoder=encoder,
            classifier=classifier,
            method=method,
            pseudo_labeler=pseudo_labeler,
            train_encoder=ARGS.finetune_encoder,
        )

    start_epoch = 1  # start at 1 so that the val_freq works correctly
    # Resume from checkpoint
    if MISC.resume is not None:
        log.info("Restoring generator from checkpoint")
        model, start_epoch = restore_model(CFG, Path(MISC.resume), model)
        if MISC.evaluate:
            pth_path = convert_and_save_results(
                CFG,
                cluster_label_path,
                classify_dataset(CFG, model, datasets.context),
                enc_path=enc_path,
                context_metrics={},  # TODO: compute this
            )
            if run is not None:
                run.finish()  # this allows multiple experiments in one python process
            return model, pth_path

    # Logging
    # wandb.set_model_graph(str(generator))
    num_parameters = count_parameters(model)
    log.info(f"Number of trainable parameters: {num_parameters}")

    # best_loss = float("inf")
    best_acc = 0.0
    n_vals_without_improvement = 0
    # super_val_freq = ARGS.super_val_freq or ARGS.val_freq

    itr = 0
    # Train generator for N epochs
    for epoch in range(start_epoch, start_epoch + ARGS.epochs):
        if n_vals_without_improvement > ARGS.early_stopping > 0:
            break

        itr = train(model=model, context_data=context_loader, train_data=train_loader, epoch=epoch)

        if epoch % ARGS.val_freq == 0:
            val_acc, _, val_log = validate(model, val_loader)

            if val_acc > best_acc:
                best_acc = val_acc
                save_model(CFG, save_dir, model, epoch=epoch, sha=sha, best=True)
                n_vals_without_improvement = 0
            else:
                n_vals_without_improvement += 1

            prepare = (
                f"{k}: {v:.5g}" if isinstance(v, float) else f"{k}: {v}" for k, v in val_log.items()
            )
            log.info(
                "[VAL] Epoch {:04d} | {} | No improvement during validation: {:02d}".format(
                    epoch,
                    " | ".join(prepare),
                    n_vals_without_improvement,
                )
            )
            wandb_log(MISC, val_log, step=itr)
        # if ARGS.super_val and epoch % super_val_freq == 0:
        #     log_metrics(ARGS, model=model.bundle, data=datasets, step=itr)
        #     save_model(args, save_dir, model=model.bundle, epoch=epoch, sha=sha)

    log.info("Training has finished.")
    # path = save_model(args, save_dir, model=model, epoch=epoch, sha=sha)
    # model, _ = restore_model(args, path, model=model)
    _, test_metrics, _ = validate(model, val_loader)
    _, context_metrics, _ = validate(model, context_loader)
    log.info("Test metrics:")
    print_metrics({f"Test {k}": v for k, v in test_metrics.items()})
    log.info("Context metrics:")
    print_metrics({f"Context {k}": v for k, v in context_metrics.items()})
    pth_path = convert_and_save_results(
        CFG,
        cluster_label_path=cluster_label_path,
        results=classify_dataset(CFG, model, datasets.context),
        enc_path=enc_path,
        context_metrics=context_metrics,
        test_metrics=test_metrics,
    )
    if run is not None:
        run.finish()  # this allows multiple experiments in one python process
    return model, pth_path
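
# A sketch of a possible command-line entry point (hypothetical: the project
# presumably registers its structured Config with Hydra, so the actual wiring
# may differ):
#
#     import hydra
#
#     @hydra.main(config_path="conf", config_name="config")
#     def app(cfg: Config) -> None:
#         main(cfg)
#
#     if __name__ == "__main__":
#         app()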
def train(model: Model, context_data: DataLoader, train_data: DataLoader, epoch: int) -> int:
    total_loss_meter = AverageMeter()
    loss_meters: Optional[Dict[str, AverageMeter]] = None
    time_meter = AverageMeter()
    start_epoch_time = time.time()
    end = start_epoch_time
    epoch_len = min(len(context_data), len(train_data))
    itr = start_itr = (epoch - 1) * epoch_len
    data_iterator = zip(context_data, train_data)
    model.train()
    s_count = MISC._s_dim if MISC._s_dim > 1 else 2

    for itr, ((x_c, _, _), (x_t, s_t, y_t)) in enumerate(data_iterator, start=start_itr):
        x_c, x_t, y_t, s_t = to_device(x_c, x_t, y_t, s_t)

        if ARGS.with_supervision and not ARGS.use_multi_head:
            class_id = get_class_id(s=s_t, y=y_t, s_count=s_count, to_cluster=ARGS.cluster)
            loss_sup, logging_sup = model.supervised_loss(
                x_t, class_id, ce_weight=ARGS.sup_ce_weight, bce_weight=ARGS.sup_bce_weight
            )
        else:
            loss_sup = x_t.new_zeros(())
            logging_sup = {}
        loss_unsup, logging_unsup = model.unsupervised_loss(x_c)
        loss = loss_sup + loss_unsup

        model.zero_grad()
        loss.backward()
        model.step()

        # Log losses
        logging_dict = {**logging_unsup, **logging_sup}
        total_loss_meter.update(loss.item())
        if loss_meters is None:
            loss_meters = {name: AverageMeter() for name in logging_dict}
        for name, value in logging_dict.items():
            loss_meters[name].update(value)

        time_for_batch = time.time() - end
        time_meter.update(time_for_batch)

        wandb_log(MISC, logging_dict, step=itr)
        end = time.time()

    time_for_epoch = time.time() - start_epoch_time
    assert loss_meters is not None
    to_log = "[TRN] Epoch {:04d} | Duration: {} | Batches/s: {:.4g} | {} ({:.5g})".format(
        epoch,
        readable_duration(time_for_epoch),
        1 / time_meter.avg,
        " | ".join(f"{name}: {meter.avg:.5g}" for name, meter in loss_meters.items()),
        total_loss_meter.avg,
    )
    log.info(to_log)
    return itr
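
# A tiny self-contained illustration of the class-ID convention used throughout
# (class_id = y * s_count + s, as noted in `main`; the helper names here are
# assumed, not part of the project):
def _to_class_id(y: int, s: int, s_count: int) -> int:
    return y * s_count + s


def _from_class_id(class_id: int, s_count: int) -> Tuple[int, int]:
    return divmod(class_id, s_count)  # (y, s)


assert _from_class_id(_to_class_id(y=1, s=2, s_count=3), s_count=3) == (1, 2)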