# Ejemplo n.º 1
# 0
def main():
    """Train the weight-sharing supernet and, each epoch, evaluate a batch of
    randomly sampled sub-networks on the validation set, checkpointing the
    shared weights as training proceeds."""
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)

    # Reproducibility and cuDNN configuration.
    cudnn.enabled = True
    cudnn.benchmark = True
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    logging.info("unparsed args = %s", unparsed)

    # prepare dataset: pick transforms and dataset class in one branch.
    if args.is_cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
        dataset_cls = dset.CIFAR100
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
        dataset_cls = dset.CIFAR10
    train_data = dataset_cls(root=args.tmp_data_dir, train=True,
                             download=False, transform=train_transform)
    valid_data = dataset_cls(root=args.tmp_data_dir, train=False,
                             download=False, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False,
        pin_memory=True, num_workers=args.workers)

    # build Network
    criterion = nn.CrossEntropyLoss().cuda()
    supernet = Network(args.init_channels, CIFAR_CLASSES, args.layers)
    supernet.cuda()

    # Heavier weight decay for CIFAR-100, lighter for CIFAR-10.
    weight_decay = 5e-4 if args.is_cifar100 else 3e-4
    optimizer = torch.optim.SGD(supernet.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    for epoch in range(args.epochs):
        logging.info('epoch %d lr %e', epoch, scheduler.get_last_lr()[0])

        train_acc, train_obj = train(train_queue, supernet, criterion,
                                     optimizer)
        logging.info('train_acc %f', train_acc)

        # Sample args.eval_time architectures from the supernet and report
        # their mean validation accuracy.
        valid_top1 = utils.AverageMeter()
        for _ in range(args.eval_time):
            supernet.generate_share_alphas()
            sampled_alphas = supernet.cells[0].ops_alphas
            subnet = supernet.get_sub_net(sampled_alphas)

            valid_acc, valid_obj = infer(valid_queue, subnet, criterion)
            valid_top1.update(valid_acc)
        logging.info('Mean Valid Acc: %f', valid_top1.avg)
        scheduler.step()

        # Checkpoint the shared weights every epoch.
        utils.save(supernet, os.path.join(args.save, 'supernet_weights.pt'))
def main():
    """Retrain a previously exported sub-network from scratch and evaluate it.

    Loads the architecture (alpha matrix) saved by the evaluation script,
    derives the corresponding sub-network from the supernet, trains it on
    CIFAR, logs validation metrics to TensorBoard, saves the weights, and
    finally writes validation accuracy plus out-of-distribution AUCs to disk.
    """
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.super_seed)
    cudnn.benchmark = True
    torch.manual_seed(args.super_seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.super_seed)
    logging.info("args = %s", args)
    logging.info("unparsed args = %s", unparsed)

    # prepare dataset: pick transforms and dataset class in one branch.
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(args)
        dataset_cls = dset.CIFAR100
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
        dataset_cls = dset.CIFAR10
    train_data = dataset_cls(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    valid_data = dataset_cls(root=args.tmp_data_dir, train=False, download=True, transform=valid_transform)

    # drop_last keeps every batch full so per-batch statistics are comparable.
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers, drop_last=True)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers, drop_last=True)

    # OOD loaders reuse the in-distribution validation transform.
    ood_queues = {}
    for k in ['svhn', 'lsun_resized', 'imnet_resized']:
        ood_path = os.path.join(args.ood_dir, k)
        dset_ = dset.ImageFolder(ood_path, valid_transform)
        ood_queues[k] = torch.utils.data.DataLoader(
            dset_, batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=args.workers
        )

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    supernet = Network(
        args.init_channels, CIFAR_CLASSES, args.layers,
        combine_method=args.feat_comb, is_cosine=args.is_cosine,
    )
    supernet.cuda()
    supernet.generate_share_alphas()   # This is to prevent supernet alpha attribute being None type

    # The stored alphas and all output files share this directory; compute it
    # once instead of re-formatting the same path later.
    out_dir = './results/{}/eval_out/{}/'.format(args.load_at.split('/')[2], args.folder)
    alphas_path = os.path.join(out_dir, 'alphas.pt')
    logging.info('Loading alphas at: %s', alphas_path)  # lazy %-args, not eager %
    alphas = torch.load(alphas_path)

    # The last alpha column is dropped before deriving the sub-network.
    # NOTE(review): presumably it is not part of the operation encoding —
    # confirm against Network.get_sub_net's expected layout.
    subnet = supernet.get_sub_net(alphas[:, :-1])
    logging.info(alphas)

    # Heavier weight decay for CIFAR-100, lighter for CIFAR-10.
    weight_decay = 5e-4 if args.cifar100 else 3e-4
    optimizer = torch.optim.SGD(
        subnet.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    for epoch in range(args.epochs):
        logging.info('epoch {} lr {:.4f}'.format(epoch, scheduler.get_last_lr()[0]))

        train_acc, _ = train(train_queue, subnet, criterion, optimizer)
        logging.info('train_acc {:.2f}'.format(train_acc))

        valid_acc, valid_loss = infer(valid_queue, subnet, criterion)
        # NOTE(review): global_step is module-level and not advanced in this
        # loop; unless it is updated elsewhere (e.g. inside train()), these
        # scalars overwrite each other in TensorBoard — verify.
        writer_va.add_scalar('loss', valid_loss, global_step)
        writer_va.add_scalar('acc', valid_acc, global_step)
        logging.info('valid_acc {:.2f}'.format(valid_acc))
        scheduler.step()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.ckpt_path, exist_ok=True)
    utils.save(subnet, os.path.join(args.ckpt_path, 'subnet_{}_weights.pt'.format(args.folder)))

    lg_aucs, sm_aucs, ent_aucs = ood_eval(valid_queue, ood_queues, subnet, criterion)

    logging.info('Writting results:')
    with open(os.path.join(out_dir, 'subnet_scratch.txt'), 'w') as f:
        f.write('-'.join([str(valid_acc), str(lg_aucs), str(sm_aucs), str(ent_aucs)]))
# Ejemplo n.º 3
# 0
def main():
    """Sample a sub-network from a trained supernet checkpoint and evaluate it.

    Restores the supernet weights, samples shared alphas, persists the sampled
    architecture (tensor + human-readable 0/1 matrix), measures the inherited
    validation accuracy, and — when ``args.fine_tune`` is set — fine-tunes a
    fresh copy of the sub-network twice (whole network, then classifier only),
    writing all accuracies to ``results.txt``.
    """
    if not torch.cuda.is_available():
        print('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    # Bug fix: print() does not apply logging-style %-substitution; the
    # original printed the literal "%s" and the namespace as a tuple.
    print("args = %s" % (args,))
    print("unparsed args = %s" % (unparsed,))

    # prepare dataset: pick transforms and dataset class in one branch.
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(args)
        dataset_cls = dset.CIFAR100
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
        dataset_cls = dset.CIFAR10
    train_data = dataset_cls(root=args.tmp_data_dir, train=True, download=False, transform=train_transform)
    valid_data = dataset_cls(root=args.tmp_data_dir, train=False, download=False, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    supernet = Network(
        args.init_channels, CIFAR_CLASSES, args.layers
    )
    supernet.cuda()

    # Restore trained supernet weights, then sample a shared architecture.
    ckpt = torch.load(args.load_at)
    print(args.load_at)
    supernet.load_state_dict(ckpt)
    supernet.generate_share_alphas()

    alphas = supernet.cells[0].ops_alphas
    print(alphas)
    out_dir = args.save + '{}/eval_out/{}'.format(args.load_at.split('/')[2], args.seed)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)
    torch.save(alphas, os.path.join(out_dir, 'alphas.pt'))
    # Also dump a human-readable 0/1 matrix, one alpha row per line.
    with open(os.path.join(out_dir, 'alphas.txt'), 'w') as f:
        for row in alphas.cpu().detach().numpy():
            f.write(''.join('{:d}'.format(int(v)) for v in row))
            f.write('\n')

    # Getting subnet according to sample alpha
    subnet = supernet.get_sub_net(alphas)

    init_valid_acc, _ = infer(valid_queue, subnet, criterion)
    print('Initial Valid Acc {:.2f}'.format(init_valid_acc))

    if args.fine_tune:
        # Heavier weight decay for CIFAR-100, lighter for CIFAR-10.
        weight_decay = 5e-4 if args.cifar100 else 3e-4

        # Fine tuning whole network, starting from a fresh copy of the
        # inherited supernet weights:
        subnet = supernet.get_sub_net(alphas)
        optimizer = torch.optim.SGD(
            subnet.parameters(),
            args.finetune_lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
        )

        for epoch in range(args.epochs):
            # Constant LR: no scheduler during fine-tuning.
            print('epoch {} lr {:.4f}'.format(epoch, args.finetune_lr))

            train_acc, _ = train(train_queue, subnet, criterion, optimizer)
            print('train_acc {:.2f}'.format(train_acc))

            whole_valid_acc, _ = infer(valid_queue, subnet, criterion)
            print('valid_acc after whole fine-tune {:.2f}'.format(whole_valid_acc))

            fly_whole_valid_acc, _ = infer(valid_queue, subnet, criterion, use_fly_bn=False)
            print('valid_acc after whole fine-tune {:.2f}'.format(fly_whole_valid_acc))

        # Fine-tuning only classifier on another fresh copy:
        subnet = supernet.get_sub_net(alphas)
        # Freezing other weights except classifier:
        for name, param in subnet.named_parameters():
            if 'classifier' not in name:
                param.requires_grad_(requires_grad=False)

        optimizer = torch.optim.SGD(
            subnet.classifier.parameters(),
            args.finetune_lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
        )

        for epoch in range(args.epochs):
            print('epoch {} lr {:.4f}'.format(epoch, args.finetune_lr))

            train_acc, _ = train(train_queue, subnet, criterion, optimizer)
            print('train_acc {:.2f}'.format(train_acc))

            part_valid_acc, _ = infer(valid_queue, subnet, criterion)
            print('valid_acc after fine-tuning classifier {:.2f}'.format(part_valid_acc))

            fly_part_valid_acc, _ = infer(valid_queue, subnet, criterion, use_fly_bn=False)
            print('valid_acc after fine-tuning classifier {:.2f}'.format(fly_part_valid_acc))

        with open(os.path.join(out_dir, 'results.txt'), 'w') as f:
            f.write('-'.join([str(init_valid_acc), str(whole_valid_acc),
                    str(fly_whole_valid_acc), str(part_valid_acc), str(fly_part_valid_acc)]))
    else:
        # No fine-tuning requested: record only the inherited accuracy.
        with open(os.path.join(out_dir, 'results.txt'), 'w') as f:
            f.write(str(init_valid_acc))