Beispiel #1
0
def main():
    """Train a CIFAR classifier from a YAML config, optionally resuming from a checkpoint.

    Reads ``args.cfg`` (config file) and ``args.resume`` (optional checkpoint path)
    from ``arg_parse()``, builds the model, SGD optimizer and ``Trainer``, restores
    state from the checkpoint when requested, then runs ``trainer.train()``.
    """
    args = arg_parse()

    # ---- setup device ----
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('==> Using device ' + device)

    # ---- setup config ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    seed.set_seed(cfg.SOLVER.SEED)

    # ---- setup logger ----
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    logger = logging.construct_logger('context_cnns', cfg.OUTPUT_DIR)
    logger.info(f'Using {device}')
    logger.info('\n' + cfg.dump())

    # ---- setup dataset ----
    train_loader, val_loader = get_cifar(cfg)

    # ---- setup model ----
    print('==> Building model..')
    net = get_model(cfg)
    net = net.to(device)

    # Summary is taken on the bare model, before any DataParallel wrapping.
    # (3, 32, 32) is presumably (channels, height, width) of a CIFAR image.
    model_stats = summary(net, (3, 32, 32), depth=10)
    logger.info('\n' + str(model_stats))

    if device == 'cuda':
        net = torch.nn.DataParallel(net)

    # ---- setup trainers ----
    optim = torch.optim.SGD(net.parameters(), lr=cfg.SOLVER.BASE_LR,
                            momentum=cfg.SOLVER.MOMENTUM,
                            weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    trainer = Trainer(
        device,
        train_loader,
        val_loader,
        net,
        optim,
        logger,
        cfg
    )

    if args.resume:
        # Load checkpoint
        print('==> Resuming from checkpoint..')
        # map_location remaps stored tensors onto the current device, so a checkpoint
        # saved on GPU can be opened on a CPU-only machine (torch.load would otherwise
        # try to restore tensors to the device they were saved from).
        # NOTE(review): if the checkpoint was saved from a DataParallel-wrapped model,
        # its keys carry a "module." prefix that the unwrapped CPU model will not
        # accept — confirm how Trainer saves checkpoints.
        cp = torch.load(args.resume, map_location=device)
        trainer.model.load_state_dict(cp['net'])
        trainer.optim.load_state_dict(cp['optim'])
        trainer.epochs = cp['epoch']
        trainer.train_acc = cp['train_accuracy']
        trainer.val_acc = cp['test_accuracy']

    trainer.train()
Beispiel #2
0
def main():
    """Train a CIFAR classifier (logger name "isonet") from a YAML config.

    Reads ``args.cfg`` (config file), ``args.output`` (output sub-directory) and
    ``args.resume`` (optional checkpoint path) from ``arg_parse()``. Outputs go to
    ``cfg.OUTPUT_DIR/cfg.DATASET.NAME/args.output``. The model is trained with SGD
    via ``Trainer``, optionally restoring state from a checkpoint first.
    """
    args = arg_parse()

    # ---- setup device ----
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("==> Using device " + device)

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    set_seed(cfg.SOLVER.SEED)

    # ---- setup logger and output ----
    output_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASET.NAME, args.output)
    os.makedirs(output_dir, exist_ok=True)
    logger = construct_logger("isonet", output_dir)
    logger.info("Using " + device)
    logger.info("\n" + cfg.dump())

    # ---- setup dataset ----
    train_loader, val_loader = get_cifar(cfg)

    print("==> Building model..")
    net = get_model(cfg)
    # print(net)
    net = net.to(device)
    # model_stats = summary(net, (3, 32, 32))
    # logger.info('\n'+str(model_stats))

    # Needed even for single GPU https://discuss.pytorch.org/t/attributeerror-net-object-has-no-attribute-module/45652
    if device == "cuda":
        net = torch.nn.DataParallel(net)

    optim = torch.optim.SGD(
        net.parameters(),
        lr=cfg.SOLVER.BASE_LR,
        momentum=cfg.SOLVER.MOMENTUM,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        dampening=cfg.SOLVER.DAMPENING,
        nesterov=cfg.SOLVER.NESTEROV,
    )

    trainer = Trainer(device, train_loader, val_loader, net, optim, logger,
                      output_dir, cfg)

    if args.resume:
        # Load checkpoint
        print("==> Resuming from checkpoint..")
        # map_location remaps stored tensors onto the current device, so a checkpoint
        # saved on GPU can be opened on a CPU-only machine.
        # NOTE(review): a checkpoint saved from the DataParallel-wrapped model will
        # have "module."-prefixed keys that the unwrapped CPU model will not accept —
        # confirm how Trainer saves checkpoints.
        cp = torch.load(args.resume, map_location=device)
        trainer.model.load_state_dict(cp["net"])
        trainer.optim.load_state_dict(cp["optim"])
        trainer.epochs = cp["epoch"]
        trainer.train_acc = cp["train_accuracy"]
        trainer.val_acc = cp["test_accuracy"]

    trainer.train()
Beispiel #3
0
def main():
    """Run the multi-source domain adaptation workflow described by a YAML config.

    Builds a multi-domain image dataset from ``cfg.DATASET``, then for each of
    ``cfg.DATASET.NUM_REPEAT`` repeats trains and tests a fresh model seeded with
    ``cfg.SOLVER.SEED + 10 * i``, so results can be aggregated (e.g. for std).
    """
    args = arg_parse()

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    print(cfg)

    # ---- setup output ----
    format_str = "@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)"
    logging.basicConfig(format=format_str)
    # ---- setup dataset ----
    # A sequence of source domain names selects an explicit subset of sub-domains
    # (all sources plus the target). The config may hold that sequence as a list or
    # a tuple, so accept both and normalize to a list before appending the target.
    if isinstance(cfg.DATASET.SOURCE, (list, tuple)):
        sub_domain_set = list(cfg.DATASET.SOURCE) + [cfg.DATASET.TARGET]
    else:
        sub_domain_set = None
    num_channels = cfg.DATASET.NUM_CHANNELS
    if cfg.DATASET.NAME.upper() == "DIGITS":
        kwargs = {"return_domain_label": True}
    else:
        kwargs = {"download": True, "return_domain_label": True}

    data_access = ImageAccess.get_multi_domain_images(
        cfg.DATASET.NAME.upper(), cfg.DATASET.ROOT, sub_domain_set=sub_domain_set, **kwargs
    )

    # Repeat multiple times to get std
    for i in range(0, cfg.DATASET.NUM_REPEAT):
        seed = cfg.SOLVER.SEED + i * 10
        dataset = MultiDomainAdapDataset(data_access, random_state=seed)
        set_seed(seed)  # seed_everything in pytorch_lightning did not set torch.backends.cudnn
        print(f"==> Building model for seed {seed} ......")
        # ---- setup model and logger ----
        model, train_params = get_model(cfg, dataset, num_channels)

        tb_logger = TensorBoardLogger(cfg.OUTPUT.TB_DIR, name="seed{}".format(seed))
        checkpoint_callback = ModelCheckpoint(
            filename="{epoch}-{step}-{valid_loss:.4f}", monitor="valid_loss", mode="min",
        )
        progress_bar = TQDMProgressBar(cfg.OUTPUT.PB_FRESH)

        trainer = pl.Trainer(
            min_epochs=cfg.SOLVER.MIN_EPOCHS,
            max_epochs=cfg.SOLVER.MAX_EPOCHS,
            callbacks=[checkpoint_callback, progress_bar],
            gpus=args.gpus,
            auto_select_gpus=True,
            logger=tb_logger,  # logger,
            # weights_summary='full',
            fast_dev_run=False,  # True,
        )

        trainer.fit(model)
        trainer.test()
Beispiel #4
0
def main():
    """Run the video domain-adaptation workflow described by a YAML config.

    For every repeat ``r`` in ``range(cfg.DATASET.NUM_REPEAT)`` a fresh model is
    trained with seed ``cfg.SOLVER.SEED + 10 * r``. Validation metrics (after
    ``fit``) and test metrics (after ``test``) are recorded in the ``results``
    object returned by ``setup_logger``, written to CSV and printed.
    """
    args = arg_parse()

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    print(cfg)

    # ---- setup output ----
    out_dir = cfg.OUTPUT.DIR
    os.makedirs(out_dir, exist_ok=True)
    logging.basicConfig(format="@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)")

    # ---- setup dataset ----
    source_domain = VideoDataset(cfg.DATASET.SOURCE.upper())
    target_domain = VideoDataset(cfg.DATASET.TARGET.upper())
    source, target, num_channels = VideoDataset.get_source_target(source_domain, target_domain, cfg)
    dataset = MultiDomainDatasets(
        source,
        target,
        config_weight_type=cfg.DATASET.WEIGHT_TYPE,
        config_size_type=cfg.DATASET.SIZE_TYPE,
    )

    method = cfg.DAN.METHOD
    base_seed = cfg.SOLVER.SEED
    # Repeat multiple times to get std
    for repeat in range(cfg.DATASET.NUM_REPEAT):
        seed = base_seed + 10 * repeat
        # seed_everything in pytorch_lightning did not set torch.backends.cudnn
        set_seed(seed)
        print(f"==> Building model for seed {seed} ......")

        # ---- setup model and logger ----
        model, train_params = get_model(cfg, dataset, num_channels)
        _, results, checkpoint_callback, test_csv_file = setup_logger(train_params, out_dir, method, seed)
        tb_logger = pl_loggers.TensorBoardLogger(cfg.OUTPUT.TB_DIR)
        trainer = pl.Trainer(
            progress_bar_refresh_rate=cfg.OUTPUT.PB_FRESH,  # refresh rate, in steps
            min_epochs=cfg.SOLVER.MIN_EPOCHS,
            max_epochs=cfg.SOLVER.MAX_EPOCHS,
            checkpoint_callback=checkpoint_callback,
            gpus=args.gpus,
            logger=tb_logger,
            fast_dev_run=cfg.OUTPUT.FAST_DEV_RUN,
        )

        # validation scores come from the metrics left behind by fit()
        trainer.fit(model)
        results.update(
            is_validation=True, method_name=method, seed=seed, metric_values=trainer.callback_metrics,
        )

        # test scores come from the metrics left behind by test()
        trainer.test()
        results.update(
            is_validation=False, method_name=method, seed=seed, metric_values=trainer.callback_metrics,
        )

        results.to_csv(test_csv_file)
        results.print_scores(method)
Beispiel #5
0
def main():
    """Run the digits domain-adaptation workflow described by a YAML config.

    Builds source/target ``DigitDataset`` splits wrapped in ``MultiDomainDatasets``,
    then for each repeat ``r`` in ``range(cfg.DATASET.NUM_REPEAT)`` fits and tests
    a fresh model seeded with ``cfg.SOLVER.SEED + 10 * r``.
    """
    args = arg_parse()

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    print(cfg)

    # ---- setup output ----
    logging.basicConfig(format="@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)")

    # ---- setup dataset ----
    source_domain = DigitDataset(cfg.DATASET.SOURCE.upper())
    target_domain = DigitDataset(cfg.DATASET.TARGET.upper())
    source, target, num_channels = DigitDataset.get_source_target(source_domain, target_domain, cfg.DATASET.ROOT)
    dataset = MultiDomainDatasets(
        source,
        target,
        config_weight_type=cfg.DATASET.WEIGHT_TYPE,
        config_size_type=cfg.DATASET.SIZE_TYPE,
        valid_split_ratio=cfg.DATASET.VALID_SPLIT_RATIO,
    )

    base_seed = cfg.SOLVER.SEED
    # Repeat multiple times to get std
    for repeat in range(cfg.DATASET.NUM_REPEAT):
        seed = base_seed + 10 * repeat
        # seed_everything in pytorch_lightning did not set torch.backends.cudnn
        set_seed(seed)
        print(f"==> Building model for seed {seed} ......")

        # ---- setup model and logger ----
        model, _ = get_model(cfg, dataset, num_channels)
        tb_logger = pl_loggers.TensorBoardLogger(cfg.OUTPUT.TB_DIR, name=f"seed{seed}")
        best_ckpt = ModelCheckpoint(filename="{epoch}-{step}-{valid_loss:.4f}", monitor="valid_loss", mode="min")
        progress = TQDMProgressBar(cfg.OUTPUT.PB_FRESH)

        trainer = pl.Trainer(
            min_epochs=cfg.SOLVER.MIN_EPOCHS,
            max_epochs=cfg.SOLVER.MAX_EPOCHS,
            callbacks=[best_ckpt, progress],
            logger=tb_logger,
            gpus=args.gpus,
        )

        trainer.fit(model)
        trainer.test()
Beispiel #6
0
def main():
    """Train a CIFAR classifier from a YAML config, optionally resuming from a checkpoint.

    Reads ``args.cfg`` (config file) and ``args.resume`` (optional checkpoint path)
    from ``arg_parse()``, builds the model, SGD optimizer and ``Trainer``, restores
    state from the checkpoint when requested, then runs ``trainer.train()``.
    """
    args = arg_parse()

    # ---- setup device ----
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("==> Using device " + device)

    # ---- setup config ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    seed.set_seed(cfg.SOLVER.SEED)

    # ---- setup logger ----
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    logger = logging.construct_logger("context_cnns", cfg.OUTPUT_DIR)
    logger.info(f"Using {device}")
    logger.info("\n" + cfg.dump())

    # ---- setup dataset ----
    train_loader, valid_loader = get_cifar(cfg)

    # ---- setup model ----
    print("==> Building model..")
    net = get_model(cfg)
    net = net.to(device)

    # Summary is taken on the bare model, before any DataParallel wrapping.
    model_stats = summary(net, (3, 32, 32))
    logger.info("\n" + str(model_stats))

    if device == "cuda":
        net = torch.nn.DataParallel(net)

    # ---- setup trainers ----
    optim = torch.optim.SGD(
        net.parameters(), lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY
    )

    trainer = Trainer(device, train_loader, valid_loader, net, optim, logger, cfg)

    if args.resume:
        # Load checkpoint
        print("==> Resuming from checkpoint..")
        # map_location remaps stored tensors onto the current device, so a checkpoint
        # saved on GPU can be opened on a CPU-only machine.
        # NOTE(review): a checkpoint saved from the DataParallel-wrapped model will
        # have "module."-prefixed keys that the unwrapped CPU model will not accept —
        # confirm how Trainer saves checkpoints.
        cp = torch.load(args.resume, map_location=device)
        trainer.model.load_state_dict(cp["net"])
        trainer.optim.load_state_dict(cp["optim"])
        trainer.epochs = cp["epoch"]
        trainer.train_acc = cp["train_accuracy"]
        trainer.valid_acc = cp["test_accuracy"]

    trainer.train()
Beispiel #7
0
def main():
    """Train GripNet on the downloaded ``pose.pt`` graph data from a YAML config.

    Downloads ``pose.pt`` from ``cfg.DATASET.URL`` into ``cfg.DATASET.ROOT``, builds
    a ``GripNet`` model with an Adam optimizer, optionally restores model, optimizer
    and metric history from ``args.resume``, then runs ``trainer.train()``.
    """
    args = arg_parse()
    # ---- setup device ----
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("==> Using device " + device)

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    seed.set_seed(cfg.SOLVER.SEED)
    # ---- setup logger and output ----
    output_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASET.NAME, args.output)
    os.makedirs(output_dir, exist_ok=True)
    logger = lu.construct_logger("gripnet", output_dir)
    logger.info("Using " + device)
    logger.info(cfg.dump())
    # ---- setup dataset ----
    download_file_by_url(cfg.DATASET.URL, cfg.DATASET.ROOT, "pose.pt", "pt")
    data = torch.load(os.path.join(cfg.DATASET.ROOT, "pose.pt"))
    # From here on ``device`` is a torch.device (it was a plain string above, where it
    # is concatenated into log messages).
    device = torch.device(device)
    data = data.to(device)
    # ---- setup model ----
    print("==> Building model..")
    model = GripNet(
        cfg.GRIPN.GG_LAYERS, cfg.GRIPN.GD_LAYERS, cfg.GRIPN.DD_LAYERS, data.n_d_node, data.n_g_node, data.n_dd_edge_type
    ).to(device)
    # TODO Visualize model
    # ---- setup trainers ----
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.SOLVER.BASE_LR)
    # TODO
    trainer = Trainer(cfg, device, data, model, optimizer, logger, output_dir)

    if args.resume:
        # Load checkpoint
        print("==> Resuming from checkpoint..")
        # map_location remaps stored tensors onto the current device, so a checkpoint
        # saved on GPU can be resumed on a CPU-only machine (and vice versa).
        cp = torch.load(args.resume, map_location=device)
        trainer.model.load_state_dict(cp["net"])
        trainer.optim.load_state_dict(cp["optim"])
        trainer.epochs = cp["epoch"]
        trainer.train_auprc = cp["train_auprc"]
        trainer.valid_auprc = cp["valid_auprc"]
        trainer.train_auroc = cp["train_auroc"]
        trainer.valid_auroc = cp["valid_auroc"]
        trainer.train_ap = cp["train_ap"]
        trainer.valid_ap = cp["valid_ap"]

    trainer.train()
Beispiel #8
0
def main():
    """Train GripNet on the dataset built by ``construct_dataset(C)``.

    Uses the module-level config ``C`` (merged from ``args.cfg``), builds a
    ``GripNet`` model with an Adam optimizer, optionally restores model, optimizer
    and metric history from ``args.resume``, then runs ``trainer.train()``.
    """
    args = arg_parse()
    # ---- setup device ----
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('==> Using device ' + device)

    # ---- setup configs ----
    C.merge_from_file(args.cfg)
    C.freeze()
    seed.set_seed(C.SOLVER.SEED)
    # ---- setup logger and output ----
    output_dir = os.path.join(C.OUTPUT_DIR, C.DATASET.NAME, args.output)
    os.makedirs(output_dir, exist_ok=True)
    logger = lu.construct_logger('gripnet', output_dir)
    logger.info('Using ' + device)
    logger.info(C.dump())
    # ---- setup dataset ----
    data = construct_dataset(C)
    # From here on ``device`` is a torch.device (it was a plain string above, where it
    # is concatenated into log messages).
    device = torch.device(device)
    data = data.to(device)
    # ---- setup model ----
    print('==> Building model..')
    model = GripNet(C.GRIPN.GG_LAYERS, C.GRIPN.GD_LAYERS, C.GRIPN.DD_LAYERS,
                    data.n_d_node, data.n_g_node,
                    data.n_dd_edge_type).to(device)
    # TODO Visualize model
    # ---- setup trainers ----
    optimizer = torch.optim.Adam(model.parameters(), lr=C.SOLVER.BASE_LR)
    # TODO
    trainer = Trainer(C, device, data, model, optimizer, logger, output_dir)

    if args.resume:
        # Load checkpoint
        print('==> Resuming from checkpoint..')
        # map_location remaps stored tensors onto the current device, so a checkpoint
        # saved on GPU can be resumed on a CPU-only machine (and vice versa).
        cp = torch.load(args.resume, map_location=device)
        trainer.model.load_state_dict(cp['net'])
        trainer.optim.load_state_dict(cp['optim'])
        trainer.epochs = cp['epoch']
        trainer.train_auprc = cp['train_auprc']
        trainer.val_auprc = cp['val_auprc']
        trainer.train_auroc = cp['train_auroc']
        trainer.val_auroc = cp['val_auroc']
        trainer.train_ap = cp['train_ap']
        trainer.val_ap = cp['val_ap']

    trainer.train()
Beispiel #9
0
def test_shapes():
    """CNNTransformer's "spatial" and "sequence" outputs agree once reshaped.

    Checks the output shape in both modes, then reshapes the sequence output back
    to spatial form with ``seq_to_spatial`` and compares it to the spatial output.
    """
    set_seed(36)

    CNN_OUT_HEIGHT, CNN_OUT_WIDTH, CNN_OUT_CHANNELS = 32, 32, 256
    BATCH_SIZE = 2

    # Layers are built in the same order as before so the seeded RNG gives the
    # same initial weights.
    conv = nn.Conv2d(in_channels=3, out_channels=CNN_OUT_CHANNELS, kernel_size=3, padding=1)
    pool = nn.MaxPool2d(kernel_size=2)
    model = CNNTransformer(
        cnn=nn.Sequential(conv, pool),
        cnn_output_shape=(-1, CNN_OUT_CHANNELS, CNN_OUT_HEIGHT, CNN_OUT_WIDTH),
        num_layers=4,
        num_heads=4,
        dim_feedforward=1024,
        dropout=0.1,
        output_type="spatial",
        positional_encoder=None,
    )
    model.eval()

    images = torch.randn((BATCH_SIZE, 3, 64, 64))

    spatial = model(images)
    model.output_type = "sequence"
    sequence = model(images)

    expected_spatial = (BATCH_SIZE, CNN_OUT_CHANNELS, CNN_OUT_HEIGHT, CNN_OUT_WIDTH)
    expected_sequence = (CNN_OUT_HEIGHT * CNN_OUT_WIDTH, BATCH_SIZE, CNN_OUT_CHANNELS)
    assert spatial.size() == expected_spatial
    assert sequence.size() == expected_sequence

    # simply reshape them both to have same shape
    spatial_from_sequence = seq_to_spatial(sequence, CNN_OUT_HEIGHT, CNN_OUT_WIDTH)
    testing.assert_almost_equal(spatial.detach().numpy(), spatial_from_sequence.detach().numpy())
Beispiel #10
0
def test_set_seed_torch(torch_rand):
    """After ``set_seed()`` with no arguments, torch's first draw equals the ``torch_rand`` fixture."""
    set_seed()
    testing.assert_equal(torch.rand(1).item(), torch_rand)
Beispiel #11
0
def test_set_seed_numpy(np_rand):
    """After ``set_seed()`` with no arguments, NumPy's first draw equals the ``np_rand`` fixture."""
    set_seed()
    testing.assert_equal(np.random.rand(), np_rand)
Beispiel #12
0
def test_set_seed_base(base_rand):
    """After ``set_seed()`` with no arguments, ``random.random()`` equals the ``base_rand`` fixture."""
    set_seed()
    testing.assert_equal(random.random(), base_rand)
Beispiel #13
0
def main():
    """Run the video domain-adaptation workflow described by a YAML config.

    The source/target split is built once with ``cfg.SOLVER.SEED``. Then, for each
    repeat ``i`` in ``range(cfg.DATASET.NUM_REPEAT)``, a fresh model is trained and
    tested with seed ``cfg.SOLVER.SEED + 10 * i`` so results can be aggregated
    (e.g. mean/std over runs).
    """
    args = arg_parse()

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    print(cfg)

    # ---- setup output ----
    format_str = "@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)"
    logging.basicConfig(format=format_str)
    # ---- setup dataset ----
    base_seed = cfg.SOLVER.SEED
    source, target, num_classes = VideoDataset.get_source_target(
        VideoDataset(cfg.DATASET.SOURCE.upper()), VideoDataset(cfg.DATASET.TARGET.upper()), base_seed, cfg
    )
    dataset = VideoMultiDomainDatasets(
        source,
        target,
        image_modality=cfg.DATASET.IMAGE_MODALITY,
        seed=base_seed,
        config_weight_type=cfg.DATASET.WEIGHT_TYPE,
        config_size_type=cfg.DATASET.SIZE_TYPE,
    )

    # ---- training/test process ----
    ### Repeat multiple times to get std
    for i in range(0, cfg.DATASET.NUM_REPEAT):
        # Each repeat is offset from the fixed base seed (s, s+10, s+20, ...).
        # Offsetting the previous iteration's seed would accumulate (s, s+10, s+30, ...).
        seed = base_seed + i * 10
        set_seed(seed)  # seed_everything in pytorch_lightning did not set torch.backends.cudnn
        print(f"==> Building model for seed {seed} ......")
        # ---- setup model and logger ----
        model, train_params = get_model(cfg, dataset, num_classes)
        tb_logger = pl_loggers.TensorBoardLogger(cfg.OUTPUT.TB_DIR, name="seed{}".format(seed))
        checkpoint_callback = ModelCheckpoint(
            # dirpath=full_checkpoint_dir,
            filename="{epoch}-{step}-{valid_loss:.4f}",
            # save_last=True,
            # save_top_k=1,
            monitor="valid_loss",
            mode="min",
        )

        ### Set early stopping
        # early_stop_callback = EarlyStopping(monitor="valid_target_acc", min_delta=0.0000, patience=100, mode="max")

        lr_monitor = LearningRateMonitor(logging_interval="epoch")
        progress_bar = TQDMProgressBar(cfg.OUTPUT.PB_FRESH)

        ### Set the lightning trainer. Comment `limit_train_batches`, `limit_val_batches`, `limit_test_batches` when
        # training. Uncomment and change the ratio to test the code on the smallest sub-dataset for efficiency in
        # debugging. Uncomment early_stop_callback to activate early stopping.
        trainer = pl.Trainer(
            min_epochs=cfg.SOLVER.MIN_EPOCHS,
            max_epochs=cfg.SOLVER.MAX_EPOCHS,
            # resume_from_checkpoint=last_checkpoint_file,
            gpus=args.gpus,
            logger=tb_logger,  # logger,
            # weights_summary='full',
            fast_dev_run=cfg.OUTPUT.FAST_DEV_RUN,  # True,
            callbacks=[lr_monitor, checkpoint_callback, progress_bar],
            # callbacks=[early_stop_callback, lr_monitor],
            # limit_train_batches=0.005,
            # limit_val_batches=0.06,
            # limit_test_batches=0.06,
        )

        ### Find learning_rate
        # lr_finder = trainer.tuner.lr_find(model, max_lr=0.1, min_lr=1e-6)
        # fig = lr_finder.plot(suggest=True)
        # fig.show()
        # logging.info(lr_finder.suggestion())

        ### Training/validation process
        trainer.fit(model)

        ### Test process
        trainer.test()
Beispiel #14
0
# import os
# import sys

# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# import kale
# from kale.embed import image_cnn, mpca, positional_encoding, attention_cnn
# from kale.loaddata import cifar_access, dataset_access, digits_access, mnistm, multi_domain, sampler, usps, videos

# # from kale.pipeline import domain_adapter # need pytorch-lightning
# from kale.predict import class_domain_nets, isonet, losses
# from kale.prepdata import image_transform, prep_cmr, tensor_reshape, video_transform
# from kale.utils import csv_logger, logger, seed
import logging

from kale.utils import seed

# Seed the random number generators through kale's seed utility before anything else runs.
seed.set_seed(2020)
# NOTE: the following imports only work with the optional graph modules, so they are
# deliberately not executed here:
#     from kale.embed import gcn, gripnet
logging.getLogger().setLevel(logging.INFO)
logging.info("kale imported")
Beispiel #15
0
    "KITCHEN;6;kitchen_train.pkl;kitchen_test.pkl",
]
# Each entry is "NAME;NUM_CLASSES;TRAIN_PKL;TEST_PKL" (field meaning presumed from the values — TODO confirm).
TARGETS = [
    "EPIC;8;epic_D1_train.pkl;epic_D1_test.pkl",
    # Other available targets, currently disabled:
    # "ADL;7;adl_P_04_train.pkl;adl_P_04_test.pkl",
    # "GTEA;6;gtea_train.pkl;gtea_test.pkl",
    # "KITCHEN;6;kitchen_train.pkl;kitchen_test.pkl",
]
ALL = [*SOURCES, *TARGETS]

IMAGE_MODALITY = ["rgb", "flow", "joint"]
WEIGHT_TYPE = ["natural", "balanced", "preset0"]
# DATASIZE_TYPE = ["max", "source"]
DATASIZE_TYPE = ["max"]
VALID_RATIO = [0.1]

seed = 36
set_seed(seed)

CLASS_SUBSETS = [[1, 3, 8]]

# Two directory levels above the current working directory.
root_dir = os.path.dirname(os.path.dirname(os.getcwd()))
url = "https://github.com/pykale/data/raw/main/videos/video_test_data.zip"


@pytest.fixture(scope="module")
def testing_cfg(download_path):
    """Module-scoped minimal config whose DATASET section points at the video test data.

    ``download_path`` is another fixture; it is joined under ``root_dir`` to form
    ``DATASET.ROOT`` (which keeps a trailing slash).
    """
    cfg = CN()
    cfg.DATASET = CN()
    dataset_cfg = cfg.DATASET
    dataset_cfg.ROOT = "/".join([root_dir, download_path, "video_test_data/"])
    dataset_cfg.IMAGE_MODALITY = "joint"
    dataset_cfg.FRAMES_PER_SEGMENT = 16
    yield cfg
Beispiel #16
0
import pytest
import torch

from kale.predict.class_domain_nets import (
    ClassNetSmallImage,
    ClassNetVideo,
    ClassNetVideoConv,
    DomainNetSmallImage,
    DomainNetVideo,
    SoftmaxNet,
)
from kale.utils.seed import set_seed

# Seed first: the random input tensors below must be reproducible across runs.
set_seed(36)
BATCH_SIZE = 2
# Basic ClassNet and DomainNet models take inputs shaped (batch_size, dimension). ClassNetVideoConv instead consumes
# the output of the I3D last average-pooling layer, shaped
# (batch_size, num_channels, frames_per_segment, height, width).
INPUT_BATCH = torch.randn((BATCH_SIZE, 128))
INPUT_BATCH_AVERAGE = torch.randn((BATCH_SIZE, 1024, 1, 1, 1))
CLASSNET_MODEL = [ClassNetSmallImage, ClassNetVideo]
DOMAINNET_MODEL = [DomainNetSmallImage, DomainNetVideo]


def test_softmaxnet_shapes():
    """SoftmaxNet maps a (BATCH_SIZE, 128) batch to a (BATCH_SIZE, n_classes) output."""
    n_classes = 8
    net = SoftmaxNet(input_dim=128, n_classes=n_classes)
    net.eval()
    assert net(INPUT_BATCH).size() == (BATCH_SIZE, n_classes)