def log_images(cfg: Config,
               image_batch,
               name,
               step,
               nsamples=64,
               nrows=8,
               monochrome=False,
               prefix=None):
    """Make a grid of the given images, save them in a file and log them with W&B"""
    prefix = "train_" if prefix is None else f"{prefix}_"
    images = image_batch[:nsamples]

    if cfg.enc.recon_loss == RL.ce:
        images = images.argmax(dim=1).float() / 255
    elif cfg.data.dataset in (DS.celeba, DS.genfaces):
        images = 0.5 * images + 0.5

    if monochrome:
        images = images.mean(dim=1, keepdim=True)
    # torchvision.utils.save_image(images, f'./experiments/finn/{prefix}{name}.png', nrow=nrows)
    shw = torchvision.utils.make_grid(images, nrow=nrows).clamp(0, 1).cpu()
    wandb_log(
        cfg.misc,
        {
            prefix + name:
            [wandb.Image(torchvision.transforms.functional.to_pil_image(shw))]
        },
        step=step,
    )
    def fit(self, train_data: DataLoader, epochs: int, device: torch.device,
            use_wandb: bool) -> None:
        self.train()

        step = 0
        logging_dict = {}
        # enc_sched = torch.optim.lr_scheduler.StepLR(self.encoder.optimizer, step_size=9, gamma=.3)
        # dec_sched = torch.optim.lr_scheduler.StepLR(self.decoder.optimizer, step_size=9, gamma=.3)
        with tqdm(total=epochs * len(train_data)) as pbar:
            for _ in range(epochs):

                for x, _, _ in train_data:

                    x = x.to(device)

                    self.zero_grad()
                    _, loss, logging_dict = self.routine(x)

                    loss.backward()
                    self.step()

                    enc_loss: float = loss.item()
                    pbar.update()
                    pbar.set_postfix(AE_loss=enc_loss)
                    if use_wandb:
                        step += 1
                        logging_dict.update({"Total Loss": enc_loss})
                        wandb_log(True, logging_dict, step)
                # enc_sched.step()
                # dec_sched.step()
        log.info("Final result from encoder training:")
        print_metrics({f"Enc {k}": v for k, v in logging_dict.items()})
def train_step(
    components: Union["AeComponents", InnComponents],
    context_data_itr: Iterator[Tuple[Tensor, Tensor, Tensor]],
    train_data_itr: Iterator[Tuple[Tensor, Tensor, Tensor]],
    itr: int,
) -> Dict[str, float]:

    disc_weight = 0.0 if itr < ARGS.warmup_steps else ARGS.disc_weight
    disc_logging: Dict[str, float] = {}  # stays empty unless the AE discriminator is updated below
    if ARGS.disc_method == DM.nn:
        # Train the discriminator on its own for a number of iterations
        for _ in range(ARGS.num_disc_updates):
            x_c, x_t, s_t, y_t = get_batch(context_data_itr=context_data_itr,
                                           train_data_itr=train_data_itr)
            if components.type_ == "ae":
                _, disc_logging = update_disc(x_c, x_t, components,
                                              itr < ARGS.warmup_steps)
            else:
                update_disc_on_inn(ARGS, x_c, x_t, components,
                                   itr < ARGS.warmup_steps)

    x_c, x_t, s_t, y_t = get_batch(context_data_itr=context_data_itr,
                                   train_data_itr=train_data_itr)
    if components.type_ == "ae":
        _, logging_dict = update(x_c=x_c,
                                 x_t=x_t,
                                 s_t=s_t,
                                 y_t=y_t,
                                 ae=components,
                                 warmup=itr < ARGS.warmup_steps)
    else:
        _, logging_dict = update_inn(args=ARGS,
                                     x_c=x_c,
                                     x_t=x_t,
                                     models=components,
                                     disc_weight=disc_weight)

    logging_dict.update(disc_logging)
    wandb_log(MISC, logging_dict, step=itr)

    # Log images
    if itr % ARGS.log_freq == 0:
        with torch.no_grad():
            if components.type_ == "ae":
                generator = components.generator
                x_log = x_t
            else:
                generator = components.inn
                x_log = x_c
            log_recons(generator=generator, x=x_log, itr=itr)
    return logging_dict
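

# Illustrative only: get_batch is a project helper that is not shown here. A common way to
# keep drawing paired batches from two loaders of different lengths, as the iterator
# arguments of train_step suggest, is to wrap each loader in itertools.cycle. Every name
# and tensor below is made up for the example.
def _example_paired_batches():
    import itertools
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    context = TensorDataset(torch.randn(10, 3), torch.zeros(10), torch.zeros(10))
    train = TensorDataset(torch.randn(6, 3), torch.ones(6), torch.ones(6))
    # note: cycle replays cached batches after the first pass, so there is no reshuffling
    context_itr = itertools.cycle(DataLoader(context, batch_size=4))
    train_itr = itertools.cycle(DataLoader(train, batch_size=4))
    x_c, _, _ = next(context_itr)
    x_t, s_t, y_t = next(train_itr)
    return x_c, x_t, s_t, y_t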
def train(
    cfg: Config,
    encoder: Encoder,
    context_data: Dataset,
    num_clusters: int,
    s_count: int,
    enc_path: Path,
) -> ClusterResults:
    # encode the training set with the encoder
    encoded = encode_dataset(cfg, context_data, encoder)
    # create data loader with one giant batch
    data_loader = DataLoader(encoded, batch_size=len(encoded), shuffle=False)
    encoded, s, y = next(iter(data_loader))
    preds = run_kmeans_faiss(
        encoded,
        nmb_clusters=num_clusters,
        cuda=str(cfg.misc._device) != "cpu",
        n_iter=cfg.clust.epochs,
        verbose=True,
    )
    cluster_ids = preds.cpu().numpy()
    # preds, _ = run_kmeans_torch(encoded, num_clusters, device=args._device, n_iter=args.epochs, verbose=True)
    counts = np.zeros((num_clusters, num_clusters), dtype=np.int64)
    counts, class_ids = count_occurances(counts, cluster_ids, s, y, s_count,
                                         cfg.clust.cluster)
    _, context_metrics, logging_dict = cluster_metrics(
        cluster_ids=cluster_ids,
        counts=counts,
        true_class_ids=class_ids.numpy(),
        num_total=preds.size(0),
        s_count=s_count,
        to_cluster=cfg.clust.cluster,
    )
    prepared = (f"{k}: {v:.5g}" if isinstance(v, float) else f"{k}: {v}"
                for k, v in logging_dict.items())
    log.info(" | ".join(prepared))
    wandb_log(cfg.misc, logging_dict, step=0)
    log.info("Context metrics:")
    print_metrics({f"Context {k}": v for k, v in context_metrics.items()})
    return ClusterResults(
        flags=flatten(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)),
        cluster_ids=preds,
        class_ids=get_class_id(s=s,
                               y=y,
                               s_count=s_count,
                               to_cluster=cfg.clust.cluster),
        enc_path=enc_path,
        context_metrics=context_metrics,
    )
    def fit(
        self,
        train_data: Union[Dataset, DataLoader],
        epochs: int,
        device: torch.device,
        use_wandb: bool,
        test_data: Optional[Union[Dataset, DataLoader]] = None,
        pred_s: bool = False,
        batch_size: int = 256,
        test_batch_size: int = 1000,
        lr_milestones: Optional[Dict] = None,
    ) -> None:
        if not isinstance(train_data, DataLoader):
            train_data = DataLoader(train_data,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    pin_memory=True)
        if test_data is not None:
            if not isinstance(test_data, DataLoader):
                test_data = DataLoader(test_data,
                                       batch_size=test_batch_size,
                                       shuffle=False,
                                       pin_memory=True)

        scheduler = None
        if lr_milestones is not None:
            scheduler = MultiStepLR(optimizer=self.optimizer, **lr_milestones)

        log.info("Training classifier...")
        pbar = trange(epochs)
        for epoch in pbar:
            self.model.train()
            for step, (x, *target) in enumerate(train_data,
                                                start=epoch * len(train_data)):
                if len(target) == 2:
                    target = target[0] if pred_s else target[1]
                else:
                    target = target[0]

                x = x.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)

                self.optimizer.zero_grad()
                loss, acc = self.routine(x, target)
                loss.backward()
                self.optimizer.step()
                wandb_log(use_wandb, {"loss": loss.item()}, step=step)
                pbar.set_postfix(epoch=epoch + 1,
                                 train_loss=loss.item(),
                                 train_acc=acc)

            if test_data is not None:

                self.model.eval()
                avg_test_acc = 0.0

                with torch.set_grad_enabled(False):
                    for x, s, y in test_data:

                        if pred_s:
                            target = s
                        else:
                            target = y

                        x = x.to(device)
                        target = target.to(device)

                        loss, acc = self.routine(x, target)
                        avg_test_acc += acc

                avg_test_acc /= len(test_data)

                pbar.set_postfix(epoch=epoch + 1, avg_test_acc=avg_test_acc)
            else:
                pbar.set_postfix(epoch=epoch + 1)

            if scheduler is not None:
                scheduler.step(epoch)
        pbar.close()
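

# Illustrative only: in the fit method above, lr_milestones is unpacked directly into
# torch.optim.lr_scheduler.MultiStepLR, so it is expected to hold that scheduler's keyword
# arguments. A minimal sketch with made-up values:
def _example_multistep_lr() -> None:
    import torch
    from torch.optim.lr_scheduler import MultiStepLR

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=0.1)
    lr_milestones = {"milestones": [30, 60], "gamma": 0.1}  # decay the LR by 10x at epochs 30 and 60
    scheduler = MultiStepLR(optimizer=optimizer, **lr_milestones)
    for _ in range(3):
        optimizer.step()
        scheduler.step()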
def main(cfg: Config,
         cluster_label_file: Optional[Path] = None) -> Tuple[Model, Path]:
    """Main function

    Args:
        cluster_label_file: path to a pth file with cluster IDs
        use_wandb: this arguments overwrites the flag

    Returns:
        the trained generator
    """
    # ==== initialize globals ====
    global ARGS, CFG, DATA, ENC, MISC
    ARGS = cfg.clust
    CFG = cfg
    DATA = cfg.data
    ENC = cfg.enc
    MISC = cfg.misc

    # ==== current git commit ====
    if os.environ.get("STARTED_BY_GUILDAI", None) == "1":
        sha = ""
    else:
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha

    use_gpu = torch.cuda.is_available() and MISC.gpu >= 0
    random_seed(MISC.seed, use_gpu)
    if cluster_label_file is not None:
        MISC.cluster_label_file = str(cluster_label_file)

    run = None
    if MISC.use_wandb:
        group = ""
        if MISC.log_method:
            group += MISC.log_method
        if MISC.exp_group:
            group += "." + MISC.exp_group
        if cfg.bias.log_dataset:
            group += "." + cfg.bias.log_dataset
        run = wandb.init(
            entity="anonymous",
            project="fcm-hydra",
            config=flatten(
                OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)),
            group=group if group else None,
            reinit=True,
        )

    save_dir = Path(to_absolute_path(MISC.save_dir)) / str(time.time())
    save_dir.mkdir(parents=True, exist_ok=True)

    log.info(str(OmegaConf.to_yaml(cfg, resolve=True, sort_keys=True)))
    log.info(f"Save directory: {save_dir.resolve()}")
    # ==== check GPU ====
    MISC._device = f"cuda:{MISC.gpu}" if use_gpu else "cpu"
    device = torch.device(MISC._device)
    log.info(
        f"{torch.cuda.device_count()} GPUs available. Using device '{device}'")

    # ==== construct dataset ====
    datasets: DatasetTriplet = load_dataset(CFG)
    log.info("Size of context-set: {}, training-set: {}, test-set: {}".format(
        len(datasets.context),
        len(datasets.train),
        len(datasets.test),
    ))
    ARGS.test_batch_size = ARGS.test_batch_size if ARGS.test_batch_size else ARGS.batch_size
    context_batch_size = round(ARGS.batch_size * len(datasets.context) /
                               len(datasets.train))
    context_loader = DataLoader(
        datasets.context,
        shuffle=True,
        batch_size=context_batch_size,
        num_workers=MISC.num_workers,
        pin_memory=True,
    )
    enc_train_data = ConcatDataset([datasets.context, datasets.train])
    if ARGS.encoder == Enc.rotnet:
        enc_train_loader = DataLoader(
            RotationPrediction(enc_train_data, apply_all=True),
            shuffle=True,
            batch_size=ARGS.batch_size,
            num_workers=MISC.num_workers,
            pin_memory=True,
            collate_fn=adaptive_collate,
        )
    else:
        enc_train_loader = DataLoader(
            enc_train_data,
            shuffle=True,
            batch_size=ARGS.batch_size,
            num_workers=MISC.num_workers,
            pin_memory=True,
        )

    train_loader = DataLoader(
        datasets.train,
        shuffle=True,
        batch_size=ARGS.batch_size,
        num_workers=MISC.num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        datasets.test,
        shuffle=False,
        batch_size=ARGS.test_batch_size,
        num_workers=MISC.num_workers,
        pin_memory=True,
    )

    # ==== construct networks ====
    input_shape = get_data_dim(context_loader)
    s_count = datasets.s_dim if datasets.s_dim > 1 else 2
    y_count = datasets.y_dim if datasets.y_dim > 1 else 2
    if ARGS.cluster == CL.s:
        num_clusters = s_count
    elif ARGS.cluster == CL.y:
        num_clusters = y_count
    else:
        num_clusters = s_count * y_count
    log.info(
        f"Number of clusters: {num_clusters}, accuracy computed with respect to {ARGS.cluster.name}"
    )
    mappings: List[str] = []
    for i in range(num_clusters):
        if ARGS.cluster == CL.s:
            mappings.append(f"{i}: s = {i}")
        elif ARGS.cluster == CL.y:
            mappings.append(f"{i}: y = {i}")
        else:
            # class_id = y * s_count + s
            mappings.append(f"{i}: (y = {i // s_count}, s = {i % s_count})")
    log.info("class IDs:\n\t" + "\n\t".join(mappings))
    feature_group_slices = getattr(datasets.context, "feature_group_slices",
                                   None)

    # ================================= encoder =================================
    encoder: Encoder
    enc_shape: Tuple[int, ...]
    if ARGS.encoder in (Enc.ae, Enc.vae):
        encoder, enc_shape = build_ae(CFG, input_shape, feature_group_slices)
    else:
        if len(input_shape) < 2:
            raise ValueError("RotNet can only be applied to image data.")
        enc_optimizer_kwargs = {"lr": ARGS.enc_lr, "weight_decay": ARGS.enc_wd}
        enc_kwargs = {
            "pretrained": False,
            "num_classes": 4,
            "zero_init_residual": True
        }
        net = (resnet18(**enc_kwargs) if DATA.dataset == DS.cmnist
               else resnet50(**enc_kwargs))

        encoder = SelfSupervised(model=net,
                                 num_classes=4,
                                 optimizer_kwargs=enc_optimizer_kwargs)
        enc_shape = (512, )
        encoder.to(device)

    log.info(f"Encoding shape: {enc_shape}")

    enc_path: Path
    if ARGS.enc_path:
        enc_path = Path(ARGS.enc_path)
        if ARGS.encoder == Enc.rotnet:
            assert isinstance(encoder, SelfSupervised)
            encoder = encoder.get_encoder()
        save_dict = torch.load(ARGS.enc_path,
                               map_location=lambda storage, loc: storage)
        encoder.load_state_dict(save_dict["encoder"])
        if "args" in save_dict:
            args_encoder = save_dict["args"]
            assert ARGS.encoder.name == args_encoder["encoder_type"]
            assert ENC.levels == args_encoder["levels"]
    else:
        encoder.fit(enc_train_loader,
                    epochs=ARGS.enc_epochs,
                    device=device,
                    use_wandb=ARGS.enc_wandb)
        if ARGS.encoder == Enc.rotnet:
            assert isinstance(encoder, SelfSupervised)
            encoder = encoder.get_encoder()
        # the args names follow the convention of the standalone VAE commandline args
        args_encoder = {
            "encoder_type": ARGS.encoder.name,
            "levels": ENC.levels
        }
        enc_path = save_dir.resolve() / "encoder"
        torch.save({
            "encoder": encoder.state_dict(),
            "args": args_encoder
        }, enc_path)
        log.info(f"To make use of this encoder:\n--enc-path {enc_path}")
        if ARGS.enc_wandb:
            log.info("Stopping here because W&B will be messed up...")
            if run is not None:
                run.finish()  # this allows multiple experiments in one python process
            return

    cluster_label_path = get_cluster_label_path(MISC, save_dir)
    if ARGS.method == Meth.kmeans:
        kmeans_results = train_k_means(CFG, encoder, datasets.context,
                                       num_clusters, s_count, enc_path)
        pth = save_results(save_path=cluster_label_path,
                           cluster_results=kmeans_results)
        if run is not None:
            run.finish()  # this allows multiple experiments in one python process
        return (), pth
    if ARGS.finetune_encoder:
        encoder.freeze_initial_layers(ARGS.freeze_layers, {
            "lr": ARGS.finetune_lr,
            "weight_decay": ARGS.weight_decay
        })

    # ================================= labeler =================================
    pseudo_labeler: PseudoLabeler
    if ARGS.pseudo_labeler == PL.ranking:
        pseudo_labeler = RankingStatistics(k_num=ARGS.k_num)
    elif ARGS.pseudo_labeler == PL.cosine:
        pseudo_labeler = CosineSimThreshold(
            upper_threshold=ARGS.upper_threshold,
            lower_threshold=ARGS.lower_threshold)
    else:
        raise ValueError(f"Unsupported pseudo-labeler: {ARGS.pseudo_labeler}")

    # ================================= method =================================
    method: Method
    if ARGS.method == Meth.pl_enc:
        method = PseudoLabelEnc()
    elif ARGS.method == Meth.pl_output:
        method = PseudoLabelOutput()
    elif ARGS.method == Meth.pl_enc_no_norm:
        method = PseudoLabelEncNoNorm()
    else:
        raise ValueError(f"Unsupported method: {ARGS.method}")

    # ================================= classifier =================================
    clf_optimizer_kwargs = {"lr": ARGS.lr, "weight_decay": ARGS.weight_decay}
    clf_fn = FcNet(hidden_dims=ARGS.cl_hidden_dims)
    clf_input_shape = (prod(enc_shape), )  # FcNet first flattens the input

    classifier = build_classifier(
        input_shape=clf_input_shape,
        target_dim=s_count if ARGS.use_multi_head else num_clusters,
        model_fn=clf_fn,
        optimizer_kwargs=clf_optimizer_kwargs,
        num_heads=y_count if ARGS.use_multi_head else 1,
    )
    classifier.to(device)

    model: Union[Model, MultiHeadModel]
    if ARGS.use_multi_head:
        labeler_fn: ModelFn
        if DATA.dataset == DS.cmnist:
            labeler_fn = Mp32x23Net(batch_norm=True)
        elif DATA.dataset == DS.celeba:
            labeler_fn = Mp64x64Net(batch_norm=True)
        else:
            labeler_fn = FcNet(hidden_dims=ARGS.labeler_hidden_dims)

        labeler_optimizer_kwargs = {
            "lr": ARGS.labeler_lr,
            "weight_decay": ARGS.labeler_wd
        }
        labeler: Classifier = build_classifier(
            input_shape=input_shape,
            target_dim=s_count,
            model_fn=labeler_fn,
            optimizer_kwargs=labeler_optimizer_kwargs,
        )
        labeler.to(device)
        log.info("Fitting the labeler to the labeled data.")
        labeler.fit(
            train_loader,
            epochs=ARGS.labeler_epochs,
            device=device,
            use_wandb=ARGS.labeler_wandb,
        )
        labeler.eval()
        model = MultiHeadModel(
            encoder=encoder,
            classifiers=classifier,
            method=method,
            pseudo_labeler=pseudo_labeler,
            labeler=labeler,
            train_encoder=ARGS.finetune_encoder,
        )
    else:
        model = Model(
            encoder=encoder,
            classifier=classifier,
            method=method,
            pseudo_labeler=pseudo_labeler,
            train_encoder=ARGS.finetune_encoder,
        )

    start_epoch = 1  # start at 1 so that the val_freq works correctly
    # Resume from checkpoint
    if MISC.resume is not None:
        log.info("Restoring generator from checkpoint")
        model, start_epoch = restore_model(CFG, Path(MISC.resume), model)
        if MISC.evaluate:
            pth_path = convert_and_save_results(
                CFG,
                cluster_label_path,
                classify_dataset(CFG, model, datasets.context),
                enc_path=enc_path,
                context_metrics={},  # TODO: compute this
            )
            if run is not None:
                run.finish()  # this allows multiple experiments in one python process
            return model, pth_path

    # Logging
    # wandb.set_model_graph(str(generator))
    num_parameters = count_parameters(model)
    log.info(f"Number of trainable parameters: {num_parameters}")

    # best_loss = float("inf")
    best_acc = 0.0
    n_vals_without_improvement = 0
    # super_val_freq = ARGS.super_val_freq or ARGS.val_freq

    itr = 0
    # Train generator for N epochs
    for epoch in range(start_epoch, start_epoch + ARGS.epochs):
        if n_vals_without_improvement > ARGS.early_stopping > 0:
            break

        itr = train(model=model,
                    context_data=context_loader,
                    train_data=train_loader,
                    epoch=epoch)

        if epoch % ARGS.val_freq == 0:
            val_acc, _, val_log = validate(model, val_loader)

            if val_acc > best_acc:
                best_acc = val_acc
                save_model(CFG,
                           save_dir,
                           model,
                           epoch=epoch,
                           sha=sha,
                           best=True)
                n_vals_without_improvement = 0
            else:
                n_vals_without_improvement += 1

            prepare = (f"{k}: {v:.5g}" if isinstance(v, float) else f"{k}: {v}"
                       for k, v in val_log.items())
            log.info("[VAL] Epoch {:04d} | {} | "
                     "No improvement during validation: {:02d}".format(
                         epoch,
                         " | ".join(prepare),
                         n_vals_without_improvement,
                     ))
            wandb_log(MISC, val_log, step=itr)
        # if ARGS.super_val and epoch % super_val_freq == 0:
        #     log_metrics(ARGS, model=model.bundle, data=datasets, step=itr)
        #     save_model(args, save_dir, model=model.bundle, epoch=epoch, sha=sha)

    log.info("Training has finished.")
    # path = save_model(args, save_dir, model=model, epoch=epoch, sha=sha)
    # model, _ = restore_model(args, path, model=model)
    _, test_metrics, _ = validate(model, val_loader)
    _, context_metrics, _ = validate(model, context_loader)
    log.info("Test metrics:")
    print_metrics({f"Test {k}": v for k, v in test_metrics.items()})
    log.info("Context metrics:")
    print_metrics({f"Context {k}": v for k, v in context_metrics.items()})
    pth_path = convert_and_save_results(
        CFG,
        cluster_label_path=cluster_label_path,
        results=classify_dataset(CFG, model, datasets.context),
        enc_path=enc_path,
        context_metrics=context_metrics,
        test_metrics=test_metrics,
    )
    if run is not None:
        run.finish()  # this allows multiple experiments in one python process
    return model, pth_path
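

# Illustrative only: count_occurances and cluster_metrics (used when scoring the k-means
# clusters) are project helpers whose code is not shown here. The standard recipe for this
# kind of evaluation is sketched below as an assumption, not as their actual implementation:
# build a cluster-vs-class contingency matrix, then find the best one-to-one matching with
# the Hungarian algorithm.
def _example_cluster_accuracy(cluster_ids, class_ids, num_clusters):
    import numpy as np
    from scipy.optimize import linear_sum_assignment

    counts = np.zeros((num_clusters, num_clusters), dtype=np.int64)
    for cluster, true_class in zip(cluster_ids, class_ids):
        counts[cluster, true_class] += 1
    rows, cols = linear_sum_assignment(counts, maximize=True)
    return counts[rows, cols].sum() / len(cluster_ids)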
def train(model: Model, context_data: DataLoader, train_data: DataLoader,
          epoch: int) -> int:
    total_loss_meter = AverageMeter()
    loss_meters: Optional[Dict[str, AverageMeter]] = None

    time_meter = AverageMeter()
    start_epoch_time = time.time()
    end = start_epoch_time
    epoch_len = min(len(context_data), len(train_data))
    itr = start_itr = (epoch - 1) * epoch_len
    data_iterator = zip(context_data, train_data)
    model.train()
    s_count = MISC._s_dim if MISC._s_dim > 1 else 2

    for itr, ((x_c, _, _), (x_t, s_t, y_t)) in enumerate(data_iterator,
                                                         start=start_itr):

        x_c, x_t, y_t, s_t = to_device(x_c, x_t, y_t, s_t)

        if ARGS.with_supervision and not ARGS.use_multi_head:
            class_id = get_class_id(s=s_t,
                                    y=y_t,
                                    s_count=s_count,
                                    to_cluster=ARGS.cluster)
            loss_sup, logging_sup = model.supervised_loss(
                x_t,
                class_id,
                ce_weight=ARGS.sup_ce_weight,
                bce_weight=ARGS.sup_bce_weight)
        else:
            loss_sup = x_t.new_zeros(())
            logging_sup = {}
        loss_unsup, logging_unsup = model.unsupervised_loss(x_c)
        loss = loss_sup + loss_unsup

        model.zero_grad()
        loss.backward()
        model.step()

        # Log losses
        logging_dict = {**logging_unsup, **logging_sup}
        total_loss_meter.update(loss.item())
        if loss_meters is None:
            loss_meters = {name: AverageMeter() for name in logging_dict}
        for name, value in logging_dict.items():
            loss_meters[name].update(value)

        time_for_batch = time.time() - end
        time_meter.update(time_for_batch)

        wandb_log(MISC, logging_dict, step=itr)
        end = time.time()

    time_for_epoch = time.time() - start_epoch_time
    assert loss_meters is not None
    to_log = "[TRN] Epoch {:04d} | Duration: {} | Batches/s: {:.4g} | {} ({:.5g})".format(
        epoch,
        readable_duration(time_for_epoch),
        1 / time_meter.avg,
        " | ".join(f"{name}: {meter.avg:.5g}"
                   for name, meter in loss_meters.items()),
        total_loss_meter.avg,
    )
    log.info(to_log)
    return itr
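

# Illustrative only: AverageMeter is a project utility that is not shown in this file. The
# train loop above only needs a running mean from it, so a minimal stand-in (an assumption,
# not the project's implementation) could look like this:
class _ExampleAverageMeter:
    def __init__(self) -> None:
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value: float, n: int = 1) -> None:
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count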