Example #1
File: print_geno.py  Project: 2BH/NAS_K49
def main():
    # Seed NumPy and PyTorch (CPU and CUDA) for reproducibility,
    # and enable cuDNN benchmarking.
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    print("args = %s" % args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    model = Network(args.init_channels, args.input_channels, num_classes,
                    args.layers, criterion)

    model = model.cuda()

    # Restore the weights saved during the architecture search run.
    model.load_state_dict(torch.load(log_path + '/weights.pt'))

    print(model.genotype())
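
Note: this snippet reads module-level globals (args, num_classes, log_path, Network) that are defined elsewhere in print_geno.py. A minimal sketch of the argparse setup it implies; the flag names mirror the attributes main() uses, while the defaults are illustrative assumptions, not values from the original project:

import argparse

# Hypothetical parser covering only the attributes main() reads.
parser = argparse.ArgumentParser('print_geno')
parser.add_argument('--seed', type=int, default=2)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--init_channels', type=int, default=16)
parser.add_argument('--input_channels', type=int, default=1)
parser.add_argument('--layers', type=int, default=8)
args = parser.parse_args()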
Example #2
File: search.py  Project: m-pana/AutoSpeech
def main():
    args = parse_args()
    update_config(cfg, args)

    # cuDNN-related settings
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # Set the random seed manually for reproducibility.
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)

    # Loss
    criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda()

    # model and optimizer
    print(f"Definining network with {cfg.MODEL.LAYERS} layers...")
    model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, criterion, primitives_2,
                    drop_path_prob=cfg.TRAIN.DROPPATH_PROB)
    model = model.cuda()

    # weight params: everything except the architecture parameters, filtered by object id
    arch_params = list(map(id, model.arch_parameters()))
    weight_params = filter(lambda p: id(p) not in arch_params,
                           model.parameters())

    # Optimizer
    optimizer = optim.Adam(
        weight_params,
        lr=cfg.TRAIN.LR
    )

    # resume && make log dir and logger
    if args.load_path and os.path.exists(args.load_path):
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)

        # load checkpoint
        begin_epoch = checkpoint['epoch']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        best_acc1 = checkpoint['best_acc1']
        optimizer.load_state_dict(checkpoint['optimizer'])
        args.path_helper = checkpoint['path_helper']

        logger = create_logger(args.path_helper['log_path'])
        logger.info("=> loaded checkpoint '{}'".format(checkpoint_file))
    else:
        exp_name = args.cfg.split('/')[-1].split('.')[0]
        args.path_helper = set_path('logs_search', exp_name)
        logger = create_logger(args.path_helper['log_path'])
        begin_epoch = cfg.TRAIN.BEGIN_EPOCH
        best_acc1 = 0.0
        last_epoch = -1

    logger.info(args)
    logger.info(cfg)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, 'models', cfg.MODEL.NAME + '.py'),
        args.path_helper['ckpt_path'])

    # Datasets and dataloaders

    # The toy dataset is downloaded with 10 items per partition. Remove the sample_size parameter to use the full toy dataset.
    asv_train, asv_dev, asv_eval = asv_toys(sample_size=10)

    train_dataset = asv_train  # MNIST('mydata', transform=totensor, train=True, download=True)
    val_dataset = asv_dev  # MNIST('mydata', transform=totensor, train=False, download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )
    print(f'search.py: Train loader of {len(train_loader)} batches')
    print(f'Tot train set: {len(train_dataset)}')
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )
    print(f'search.py: Val loader of {len(val_loader)} batches')
    print(f'Tot val set {len(val_dataset)}')
    test_dataset = asv_eval  # MNIST('mydata', transform=totensor, train=False, download=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )

    # training setting
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': begin_epoch * len(train_loader),
        'valid_global_steps': begin_epoch // cfg.VAL_FREQ,
    }

    # training loop
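    # DARTS-style bilevel search: the Architect updates the architecture
    # parameters (typically against validation batches inside train()),
    # while `optimizer` updates only the weight parameters filtered above.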
    architect = Architect(model, cfg)
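    # Cosine annealing decays the LR from cfg.TRAIN.LR to cfg.TRAIN.LR_MIN
    # over cfg.TRAIN.END_EPOCH epochs; last_epoch lets a resumed run pick up
    # the schedule where the checkpoint left off.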
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN,
        last_epoch=last_epoch
    )

    for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='search progress'):
        model.train()

        genotype = model.genotype()
        logger.info('genotype = %s', genotype)

        # Linearly ramp the drop-path probability from 0 in the first epoch
        # to its configured value in the last epoch.
        if cfg.TRAIN.DROPPATH_PROB != 0:
            model.drop_path_prob = cfg.TRAIN.DROPPATH_PROB * epoch / (cfg.TRAIN.END_EPOCH - 1)

        train(cfg, model, optimizer, train_loader, val_loader, criterion, architect, epoch, writer_dict)

        if epoch % cfg.VAL_FREQ == 0:
            # evaluate identification accuracy (note: this passes the eval split's test_loader)
            acc = validate_identification(cfg, model, test_loader, criterion)

            # remember best acc@1 and save checkpoint
            is_best = acc > best_acc1
            best_acc1 = max(acc, best_acc1)

            # save
            logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path']))
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'arch': model.arch_parameters(),
                'genotype': genotype,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch))

        # Step the cosine LR schedule once per epoch (the explicit epoch
        # argument to step() is deprecated in recent PyTorch).
        lr_scheduler.step(epoch)
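
The id()-based filter in the "weight params" block above is the usual DARTS-era trick for keeping architecture parameters out of the weight optimizer, so only the Architect touches them. A self-contained sketch of the pattern, where TinySupernet is a hypothetical stand-in for the real Network class:

import torch
import torch.nn as nn

class TinySupernet(nn.Module):
    # Hypothetical stand-in for Network: ordinary weights plus architecture logits.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.alpha = nn.Parameter(torch.zeros(4))  # architecture parameters

    def arch_parameters(self):
        return [self.alpha]

model = TinySupernet()
arch_ids = set(map(id, model.arch_parameters()))
weight_params = [p for p in model.parameters() if id(p) not in arch_ids]

# alpha is excluded; only the conv weight and bias remain.
print([tuple(p.shape) for p in weight_params])  # [(8, 3, 3, 3), (8,)]

Using a set of ids makes each membership test O(1), and materializing a list (instead of the lazy filter object used in search.py) keeps weight_params reusable; the filter in the original is fine because optim.Adam consumes it exactly once.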