예제 #1
0
    # Experiment setup (continues from an argparse `parser` created above this
    # view): parse CLI args, load the config, choose the device, set up logging
    # and the writer, seed RNGs, build the data pipeline, then build the generator.

    # Required experiment title; used below to name the writer and in the
    # start-of-experiment log banner.
    parser.add_argument("--title",
                        type=str,
                        help="experiment title",
                        required=True)
    args = parser.parse_args()

    # `args.cfg` is not added in the visible lines; presumably a `--cfg` option
    # is registered on the parser above this view — confirm.
    CONFIG = get_config(args.cfg)

    # Use CUDA only when the config enables it, CUDA is actually available, and
    # the config requests at least one GPU; otherwise fall back to the CPU.
    if CONFIG.cuda:
        device = torch.device("cuda" if (
            torch.cuda.is_available() and CONFIG.ngpu > 0) else "cpu")
    else:
        device = torch.device("cpu")

    # Called only for its side effect (return value discarded); presumably it
    # configures the `logging` handlers used by `logging.info` below — confirm.
    get_logger(CONFIG.log_dir)
    # NOTE(review): presumably a summary/metrics writer keyed by the experiment
    # title — confirm against get_writer's definition.
    writer = get_writer(args.title, CONFIG.write_dir)

    # Start-of-experiment banner. Note that str.format runs eagerly, even if the
    # INFO level is disabled.
    logging.info(
        "=================================== Experiment title : {} Start ==========================="
        .format(args.title))

    # Seed before the data pipeline is built, so any randomness inside it is
    # covered (assumes set_random_seed seeds the RNGs those helpers use — confirm).
    set_random_seed(CONFIG.seed)

    # Build transforms -> datasets -> loaders for the train/val/test splits.
    train_transform, val_transform, test_transform = get_transforms(CONFIG)
    train_dataset, val_dataset, test_dataset = get_dataset(
        train_transform, val_transform, test_transform, CONFIG)
    train_loader, val_loader, test_loader = get_dataloader(
        train_dataset, val_dataset, test_dataset, CONFIG)

    # NOTE(review): magic number 21 * 8 (= 168) is passed as the generator's
    # second argument. Its meaning is not visible here (possibly an
    # architecture-parameter count) — confirm, then name it or derive it
    # instead of hard-coding it.
    generator = get_generator(CONFIG, 21 * 8)
예제 #2
0
    # Experiment setup (continues from an argparse `parser` created above this
    # view): register extra CLI flags, load the config, choose the device, set
    # up logging and the writer, build the data pipeline, then build the
    # supernet, lookup table, generator and loss.

    # Boolean mode switches: each is False unless the flag is given.
    parser.add_argument("--evaluate-arch-param", action="store_true", default=False, help="whether to evaluate arch_param")
    # NOTE(review): this flag mixes '-' and '_'. argparse maps dashes to
    # underscores for the attribute (args.evaluate_lookup_table), but on the
    # command line it must be typed exactly as "--evaluate-lookup_table".
    parser.add_argument("--evaluate-lookup_table", action="store_true", default=False, help="whether to evaluate lookup_table")
    parser.add_argument("--loading-architectures", action="store_true", default=False, help="whether to load the architecture")
    parser.add_argument("--generate-architecture-parameter", action="store_true", default=False, help="generate architecture")
    # No default given, so args.target_macs is None when the flag is omitted.
    parser.add_argument("--target-macs", type=int, help="target macs")
    args = parser.parse_args()

    # `args.cfg` is not added in the visible lines; presumably a `--cfg` option
    # is registered on the parser above this view — confirm.
    CONFIG = get_config(args.cfg)

    # Use CUDA only when the config enables it, CUDA is actually available, and
    # the config requests at least one GPU; otherwise fall back to the CPU.
    if CONFIG.cuda:
        device = torch.device("cuda" if (torch.cuda.is_available() and CONFIG.ngpu > 0) else "cpu")
    else:
        device = torch.device("cpu")

    # Called only for its side effect (return value discarded); presumably it
    # configures logging under CONFIG.log_dir — confirm.
    get_logger(CONFIG.log_dir)
    writer = get_writer(CONFIG.write_dir)

    # NOTE(review): seeding is disabled, so this code does not seed any RNG and
    # runs may not be reproducible — confirm this is intentional.
    #set_random_seed(CONFIG.seed)

    # Build transforms -> datasets -> loaders for the train/val/test splits.
    train_transform, val_transform, test_transform = get_transforms(CONFIG)
    train_dataset, val_dataset, test_dataset = get_dataset(train_transform, val_transform, test_transform, CONFIG)
    train_loader, val_loader, test_loader = get_dataloader(train_dataset, val_dataset, test_dataset, CONFIG)

    # NOTE(review): the model is not moved to `device` in the visible lines;
    # presumably that happens later — confirm.
    model = Supernet(CONFIG)
    # Presumably a per-operation cost table (related to --target-macs?) — confirm.
    lookup_table = LookUpTable(CONFIG)

    # The generator's output size is taken from the supernet rather than
    # hard-coded.
    arch_param_nums = model.get_arch_param_nums()
    # Disabled alternative: construct a ConvGenerator directly instead of via get_generator.
    #generator = ConvGenerator(CONFIG.hc_dim, 1, CONFIG.hidden_dim)
    generator = get_generator(CONFIG, arch_param_nums)

    # The loss function object itself is stored here (not called). By its name,
    # it is cross-entropy with label smoothing — confirm. "encropy" looks like a
    # typo in the helper's name, which is defined elsewhere.
    criterion = cross_encropy_with_label_smoothing