class Config(object):
    """Settings for a CUB run that uses an IC backbone and a few-shot network.

    ``resnet`` / ``modify_head`` / ``ic_*`` describe the IC model and
    ``matching_net`` the few-shot model.  Both checkpoint paths (``mn_dir`` and
    ``ic_dir``) live under ``_root_path`` and are named after ``model_name``.

    Evaluating this class body has side effects: it sets
    ``CUDA_VISIBLE_DEVICES``, instantiates ``matching_net``, calls
    ``Tools.new_dir`` for both checkpoint paths and prints ``model_name`` and
    ``data_root``.
    """

    # The device mask is set before ``matching_net`` is constructed below.
    gpu_id = 0
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    num_workers = 16

    # Episodic sampling / evaluation.
    num_way, num_shot = 5, 1
    episode_size = 15
    test_episode = 600
    val_freq = 20

    # IC-side hyper-parameters and loss weights.
    ic_ratio = 1
    knn = 50
    ic_times = 2
    ic_out_dim = 512
    loss_fsl_ratio = 1.0
    loss_ic_ratio = 1.0

    # Optimisation schedule; first_epoch / t_epoch are presumably consumed by
    # RunnerTool.adjust_learning_rate1 -- confirm against RunnerTool.
    learning_rate = 0.01
    train_epoch = 1200
    first_epoch, t_epoch = 400, 200
    adjust_learning_rate = RunnerTool.adjust_learning_rate1

    # IC backbone.  Alternatives: resnet18, modify_head = False.
    resnet = resnet34
    modify_head = True

    # Few-shot network.  Alternative:
    #   MatchingNet(hid_dim=64, z_dim=64), net_name "conv4", batch_size 64.
    matching_net = ResNet12Small(avg_pool=True, drop_rate=0.1)
    net_name = "resnet12"
    batch_size = 32

    # Run name; "_head" is appended only when the IC head is modified.
    model_name = "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}".format(
        gpu_id, net_name, train_epoch, batch_size,
        num_way, num_shot, first_epoch, t_epoch,
        ic_out_dim, ic_ratio, loss_fsl_ratio, loss_ic_ratio, ic_times,
        "" if not modify_head else "_head")

    # Non-Linux hosts use the Windows path; on Linux, fall back to the second
    # mount point when the first directory does not exist.
    if "Linux" not in platform.platform():
        data_root = "F:\\data\\CUB"
    else:
        data_root = '/mnt/4T/Data/data/UFSL/CUB'
        if not os.path.isdir(data_root):
            data_root = '/media/ubuntu/4T/ALISURE/Data/UFSL/CUB'
    data_root = os.path.join(data_root, "CUBSeg")

    # Checkpoint files for the few-shot (mn) and IC networks.
    _root_path = "../cub/models_mn/two_ic_ufsl_2net_res_sgd_acc_duli"
    mn_dir = Tools.new_dir("{}/{}_mn.pkl".format(_root_path, model_name))
    ic_dir = Tools.new_dir("{}/{}_ic.pkl".format(_root_path, model_name))

    Tools.print(model_name)
    Tools.print(data_root)
class Config(object):
    """Settings for few-shot training of a ResNet-12 ``matching_net`` on CUB.

    Evaluating this class body sets ``CUDA_VISIBLE_DEVICES``, instantiates
    ``matching_net``, calls ``Tools.new_dir`` for the checkpoint path
    ``mn_dir`` and prints ``model_name``, ``data_root`` and ``mn_dir``.
    """

    # The device mask is set before ``matching_net`` is constructed below.
    gpu_id = 0
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    num_workers = 16

    # Episodic setup.  Alternative training way count: num_way = 10.
    num_way, num_shot = 5, 1
    num_way_test = 5
    episode_size = 15
    test_episode = 600
    val_freq = 20

    # Optimisation.  Alternative batch_size: 64.
    learning_rate = 0.01
    batch_size = 32
    train_epoch = 500
    first_epoch, t_epoch = 300, 100
    adjust_learning_rate = RunnerTool.adjust_learning_rate2

    # Base run name, later suffixed with the backbone tag.
    model_name = "{}_{}_{}_{}_{}_{}_{}".format(
        gpu_id, train_epoch, batch_size, num_way, num_shot,
        first_epoch, t_epoch)

    # Few-shot network.  Alternative:
    #   MatchingNet(hid_dim=64, z_dim=64) with tag "conv4".
    matching_net = ResNet12Small(avg_pool=True, drop_rate=0.1)
    model_name = "{}_{}".format(model_name, "res12")

    mn_dir = Tools.new_dir("../cub/models_mn/fsl_sgd_modify/{}.pkl".format(model_name))

    # Non-Linux hosts use the Windows path; on Linux, fall back to the second
    # mount point when the first directory does not exist.
    if "Linux" not in platform.platform():
        data_root = "F:\\data\\CUB"
    else:
        data_root = '/mnt/4T/Data/data/UFSL/CUB'
        if not os.path.isdir(data_root):
            data_root = '/media/ubuntu/4T/ALISURE/Data/UFSL/CUB'

    Tools.print(model_name)
    Tools.print(data_root)
    Tools.print(mn_dir)
class Config(object):
    """Settings for few-shot training of a ResNet-12 ``matching_net`` on tiered-ImageNet.

    Evaluating this class body sets ``CUDA_VISIBLE_DEVICES``, instantiates
    ``matching_net``, calls ``Tools.new_dir`` for the checkpoint path
    ``mn_dir`` and prints ``model_name``, ``data_root`` and ``mn_dir``.
    """

    # Exposes four devices; alternative: gpu_id = "1".
    gpu_id = "0,1,2,3"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    num_workers = 16
    load_data = False

    # Episodic setup.
    num_way, num_shot = 5, 1
    episode_size = 15
    test_episode = 600
    val_freq = 5

    # Optimisation.
    # Alternative schedule: train_epoch = 300, train_epoch_lr = [200, 250].
    # Alternative batch sizes: 256, 128, 32.
    learning_rate = 0.001
    train_epoch, train_epoch_lr = 100, [50, 80]
    batch_size = 64

    # Commas are stripped from gpu_id before it goes into the run name.
    model_name = "{}_{}_{}_{}_{}".format(
        gpu_id.replace(",", ""), train_epoch, batch_size, num_way, num_shot)

    # Few-shot network, with its tag appended to the run name.  Alternative:
    #   MatchingNet(hid_dim=64, z_dim=64) with tag "conv4".
    matching_net = ResNet12Small(avg_pool=True, drop_rate=0.1)
    model_name = "{}_{}".format(model_name, "res12")

    mn_dir = Tools.new_dir("../tiered_imagenet/models_mn/fsl_modify/{}.pkl".format(model_name))

    # Non-Linux hosts use the Windows path; on Linux, fall back to the second
    # mount point when the first directory does not exist.
    if "Linux" not in platform.platform():
        data_root = "F:\\data\\tiered-imagenet"
    else:
        data_root = '/mnt/4T/Data/data/UFSL/tiered-imagenet'
        if not os.path.isdir(data_root):
            data_root = '/media/ubuntu/4T/ALISURE/Data/UFSL/tiered-imagenet'

    Tools.print(model_name)
    Tools.print(data_root)
    Tools.print(mn_dir)
class Config(object):
    """Settings for an IC + few-shot run on CIFARFS or FC100.

    ``dataset_name`` selects the dataset.  The IC settings (``ic_*``) and the
    few-shot settings (``fsl_*``) share one run name, ``model_name``.
    Checkpoints ``mn_dir`` / ``ic_dir`` live under ``_root_path``.
    ``ic_dir_checkpoint`` points at an existing IC checkpoint file.

    Evaluating this class body sets ``CUDA_VISIBLE_DEVICES``, instantiates
    ``fsl_matching_net``, calls ``Tools.new_dir`` for both checkpoint paths and
    prints ``model_name`` and ``data_root``.
    """

    # The device mask is set before any network is constructed below.
    gpu_id = 1
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    num_workers = 8

    # Alternative: "FC100" (any value other than "CIFARFS" takes the else-branch).
    dataset_name = "CIFARFS"

    # ---- IC model --------------------------------------------------------
    ic_ratio = 1
    ic_knn = 100
    ic_val_freq = 10
    ic_learning_rate = 0.01
    ic_train_epoch = 1500
    ic_first_epoch, ic_t_epoch = 300, 200
    ic_batch_size = 64
    ic_adjust_learning_rate = RunnerTool.adjust_learning_rate1
    ic_modify_head = True  # every active choice below uses a modified head

    # ---- Few-shot model --------------------------------------------------
    fsl_num_way = 5
    fsl_num_shot = 1
    fsl_episode_size = 15
    fsl_test_episode = 600
    fsl_val_freq = 10
    fsl_learning_rate = 0.01

    # Same network for both datasets.  Alternative:
    #   MatchingNet(hid_dim=64, z_dim=64), "conv4", batch 64, with schedule
    #   300 / [150, 250] on CIFARFS and 400 / [200, 300] on FC100.
    fsl_matching_net = ResNet12Small(avg_pool=True, drop_rate=0.1)
    fsl_net_name = "resnet12"
    fsl_batch_size = 32

    # Augmentation id, per the original comments: 1 = "other", 2 = "my".
    aug_name = 1

    if dataset_name == "CIFARFS":
        fsl_train_epoch, fsl_lr_schedule = 300, [150, 250]

        # Other IC choices that were commented out (ic_dir_checkpoint may also be None):
        #   res18 (no head mod) via resnet18, False, "res18"
        #   res34_head, out 1024 -> .../ic_res_xx/0_32_resnet_34_64_1024_1_1500_300_200_True_ic.pkl
        #   res34_head, out 512  -> .../ic_res_xx/1_CIFARFS_32_resnet_34_64_512_1_1500_300_200_True_ic.pkl
        #   res34_head, out 256  -> .../ic_res_xx/1_CIFARFS_32_resnet_34_64_256_1_1500_300_200_True_ic.pkl
        ic_resnet, ic_net_name = resnet50, "res50_head"
        ic_out_dim = 256
        ic_dir_checkpoint = "../models_CIFARFS/models/ic_res_xx/2_CIFARFS_32_resnet_50_64_256_1_1500_300_200_True_ic.pkl"
    else:
        fsl_train_epoch, fsl_lr_schedule = 200, [100, 150]

        # Alternative: ic_dir_checkpoint = None.
        ic_resnet, ic_net_name = resnet34, "res34_head"
        ic_out_dim = 512
        ic_dir_checkpoint = "../models_CIFARFS/models/ic_res_xx/1_FC100_32_resnet_34_64_512_1_1500_300_200_True_ic.pkl"

    # NOTE(review): the literal 32 mirrors the "_32_" in the checkpoint names;
    # presumably the image size -- confirm.
    model_name = "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_aug{}".format(
        gpu_id, dataset_name, 32,
        ic_net_name, ic_train_epoch, ic_batch_size, ic_out_dim,
        fsl_net_name, fsl_train_epoch, fsl_num_way, fsl_num_shot, fsl_batch_size,
        aug_name)

    # Non-Linux hosts use the Windows path; on Linux, fall back to the second
    # mount point when the first directory does not exist.
    if "Linux" not in platform.platform():
        data_root = "F:\\data\\{}".format(dataset_name)
    else:
        data_root = '/mnt/4T/Data/data/UFSL/{}'.format(dataset_name)
        if not os.path.isdir(data_root):
            data_root = '/media/ubuntu/4T/ALISURE/Data/UFSL/{}'.format(dataset_name)

    # Checkpoint files for the few-shot (mn) and IC networks.
    _root_path = "../models_CIFARFS/mn/two_ic_ufsl_2net_res_sgd_acc_duli_nete"
    mn_dir = Tools.new_dir("{}/{}_mn.pkl".format(_root_path, model_name))
    ic_dir = Tools.new_dir("{}/{}_ic.pkl".format(_root_path, model_name))

    Tools.print(model_name)
    Tools.print(data_root)
class Config(object):
    """Settings for few-shot training of a ResNet-12 ``matching_net`` on CIFARFS or FC100.

    ``dataset_name`` selects the dataset.  Only the CIFARFS path defines
    ``is_large`` and builds the network with ``large=is_large``.

    Evaluating this class body sets ``CUDA_VISIBLE_DEVICES``, instantiates
    ``matching_net``, calls ``Tools.new_dir`` for ``mn_dir`` and prints
    ``model_name``, ``data_root`` and ``mn_dir``.  The prints also go to
    ``log_file`` through ``txt_path``.
    """

    # The device mask is set before ``matching_net`` is constructed below.
    gpu_id = 3
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    num_workers = 16

    learning_rate = 0.001
    num_shot = 1
    episode_size = 15
    test_episode = 600

    # Alternative: "FC100" (any value other than "CIFARFS" takes the first branch).
    dataset_name = "CIFARFS"

    # Alternative network for either dataset:
    #   MatchingNet(hid_dim=64, z_dim=64), net_name "conv4", batch_size 64.
    if dataset_name != "CIFARFS":
        matching_net = ResNet12Small(avg_pool=True, drop_rate=0.1)
        net_name = "res12"
        train_epoch, train_epoch_lr = 60, [30, 50]
        val_freq = 2
        # Alternatives here: aug_name = 2, num_way = 20.
    else:
        is_large = True
        matching_net = ResNet12Small(avg_pool=True, drop_rate=0.1, large=is_large)
        net_name = "res12{}".format("large" if is_large else "")
        train_epoch, train_epoch_lr = 400, [200, 300]
        val_freq = 10

    # Values shared by both datasets.  aug_name: 1 = "other" (original comment).
    batch_size = 64
    aug_name = 1
    num_way = num_way_test = 5

    # NOTE(review): the literal 32 is unexplained here; presumably the image
    # size -- confirm.
    model_name = "{}_{}_{}_{}_{}_{}_{}_aug{}_{}".format(
        gpu_id, dataset_name, 32, net_name, train_epoch,
        num_way, num_shot, aug_name, val_freq)

    mn_dir = Tools.new_dir("../models_CIFARFS/mn/fsl_modify/{}.pkl".format(model_name))
    # Text log written next to the checkpoint.
    log_file = mn_dir.replace(".pkl", ".txt")

    # Non-Linux hosts use the Windows path; on Linux, fall back to the second
    # mount point when the first directory does not exist.
    if "Linux" not in platform.platform():
        data_root = "F:\\data\\{}".format(dataset_name)
    else:
        data_root = '/mnt/4T/Data/data/UFSL/{}'.format(dataset_name)
        if not os.path.isdir(data_root):
            data_root = '/media/ubuntu/4T/ALISURE/Data/UFSL/{}'.format(dataset_name)

    Tools.print(model_name, txt_path=log_file)
    Tools.print(data_root, txt_path=log_file)
    Tools.print(mn_dir, txt_path=log_file)
class Config(object):
    """Settings for a multi-GPU IC + few-shot run on tiered-ImageNet.

    Both batch sizes (``ic_batch_size`` and ``fsl_batch_size``) scale linearly
    with ``gpu_num``, the number of comma-separated ids in ``gpu_id``.
    Checkpoints ``mn_dir`` / ``ic_dir`` live under ``_root_path``.
    ``ic_dir_checkpoint`` points at an existing IC checkpoint file.

    Evaluating this class body sets ``CUDA_VISIBLE_DEVICES``, instantiates
    ``fsl_matching_net``, calls ``Tools.new_dir`` for both checkpoint paths and
    prints ``model_name`` and ``data_root``.
    """

    # The device mask is set before any network is constructed below.
    gpu_id = "0,1,2,3"
    gpu_num = gpu_id.count(",") + 1  # == len(gpu_id.split(","))
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    num_workers = 32

    # ---- IC model --------------------------------------------------------
    # Alternative backbone: resnet18, False, "res18".
    ic_resnet, ic_modify_head, ic_net_name = resnet34, True, "res34_head"
    ic_ratio = 1
    ic_knn = 100
    ic_out_dim = 2048
    ic_val_freq = 10
    ic_learning_rate = 0.01
    ic_train_epoch = 1200
    ic_first_epoch, ic_t_epoch = 400, 200
    ic_batch_size = gpu_num * 64 * 4
    ic_adjust_learning_rate = RunnerTool.adjust_learning_rate1

    # ---- Few-shot model --------------------------------------------------
    fsl_num_way = 5
    fsl_num_shot = 1
    fsl_episode_size = 15
    fsl_test_episode = 600
    fsl_learning_rate = 0.01

    # Per-GPU batch of 32.  Alternative:
    #   MatchingNet(hid_dim=64, z_dim=64), "conv4", per-GPU batch 96.
    fsl_matching_net = ResNet12Small(avg_pool=True, drop_rate=0.1)
    fsl_net_name = "resnet12"
    fsl_batch_size = 32 * gpu_num

    # Alternative schedules (val_freq, epochs, lr schedule):
    #   5, 200, [100, 150]   and   2, 100, [50, 80]
    fsl_val_freq = 2
    fsl_train_epoch = 50
    fsl_lr_schedule = [30, 40]

    # Commas are stripped from gpu_id before it goes into the run name.
    model_name = "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
        gpu_id.replace(",", ""),
        ic_net_name, ic_train_epoch, ic_batch_size, ic_out_dim,
        fsl_net_name, fsl_train_epoch, fsl_num_way, fsl_num_shot, fsl_batch_size)

    # Non-Linux hosts use the Windows path; on Linux, fall back to the second
    # mount point when the first directory does not exist.
    if "Linux" not in platform.platform():
        data_root = "F:\\data\\tiered-imagenet"
    else:
        data_root = '/mnt/4T/Data/data/UFSL/tiered-imagenet'
        if not os.path.isdir(data_root):
            data_root = '/media/ubuntu/4T/ALISURE/Data/UFSL/tiered-imagenet'

    # Quick smoke-test overrides (uncomment to use):
    # ic_batch_size = 16
    # fsl_batch_size = 16
    # ic_train_epoch = 2
    # ic_first_epoch, ic_t_epoch = 1, 1
    # ic_val_freq = 1
    # fsl_train_epoch = 8
    # fsl_lr_schedule = [4, 6]
    # fsl_test_episode = 20
    # fsl_val_freq = 2
    # data_root = os.path.join(data_root, "small")

    # Checkpoint files for the few-shot (mn) and IC networks.
    _root_path = "../tiered_imagenet/models_mn/two_ic_ufsl_2net_res_sgd_acc_duli_nete"
    mn_dir = Tools.new_dir("{}/{}_mn.pkl".format(_root_path, model_name))
    ic_dir = Tools.new_dir("{}/{}_ic.pkl".format(_root_path, model_name))

    # Existing IC checkpoint.  Alternatives: None, or
    #   "../tiered_imagenet/models/ic_res_xx/3_resnet_18_64_2048_1_1900_300_200_False_ic.pkl"
    ic_dir_checkpoint = "../tiered_imagenet/models_mn/two_ic_ufsl_2net_res_sgd_acc_duli_nete/123_res34_head_1200_384_2048_conv4_100_5_1_288_ic.pkl"

    Tools.print(model_name)
    Tools.print(data_root)