Code Example #1
File: train_search.py  Project: qingquansong/darts
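DARTS architecture search on CIFAR-10 with optional label corruption: the training indices are split into a weight-training queue and an architecture-validation queue, and with args.clean_valid the validation queue is drawn from an uncorrupted ("gold") copy of the data.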
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().cuda()
    else:
        # braces doubled so str.format does not parse them as a replacement field
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
            args.loss_func)

    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.gold_fraction == 0:
        train_data = CIFAR10(root=args.data,
                             train=True,
                             gold=False,
                             gold_fraction=args.gold_fraction,
                             corruption_prob=args.corruption_prob,
                             corruption_type=args.corruption_type,
                             transform=train_transform,
                             download=True,
                             seed=args.seed)
        if args.clean_valid:
            gold_train_data = CIFAR10(root=args.data,
                                      train=True,
                                      gold=True,
                                      gold_fraction=1.0,
                                      corruption_prob=args.corruption_prob,
                                      corruption_type=args.corruption_type,
                                      transform=train_transform,
                                      download=True,
                                      seed=args.seed)
    else:
        train_data = CIFAR10(root=args.data,
                             train=True,
                             gold=True,
                             gold_fraction=args.gold_fraction,
                             corruption_prob=args.corruption_prob,
                             corruption_type=args.corruption_type,
                             transform=train_transform,
                             download=True,
                             seed=args.seed)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
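    # indices[:split] train the network weights; indices[split:] feed the
    # validation queue that drives the architecture (alpha) updates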

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)

    if args.clean_valid:
        valid_queue = torch.utils.data.DataLoader(
            gold_train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[split:]),
            pin_memory=True,
            num_workers=2)
    else:
        valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[split:]),
            pin_memory=True,
            num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    for epoch in range(args.epochs):
        scheduler.step()  # note: since PyTorch 1.1, step() should follow the epoch's optimizer updates
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        print(F.softmax(model.alphas_normal, dim=-1))
        print(F.softmax(model.alphas_reduce, dim=-1))

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model,
                                     architect, criterion, optimizer, lr)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
Code Example #2
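Evaluates randomly sampled architectures with fixed, freshly initialized weights: each epoch the current genotype is scored on pre-materialized clean and noisy train/validation batches, then the architecture parameters (alphas) are redrawn at random.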
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    else:
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
            args.loss_func)

    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()
    model.train()
    model.apply(weights_init)
    # note: no gradients exist yet, so this clip is effectively a no-op;
    # clip_grad_norm was renamed clip_grad_norm_ in PyTorch 0.4
    nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        train_data = CIFAR10(root=args.data,
                             train=True,
                             gold=False,
                             gold_fraction=0.0,
                             corruption_prob=args.corruption_prob,
                             corruption_type=args.corruption_type,
                             transform=train_transform,
                             download=True,
                             seed=args.seed)
        gold_train_data = CIFAR10(root=args.data,
                                  train=True,
                                  gold=True,
                                  gold_fraction=1.0,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform,
                                  download=True,
                                  seed=args.seed)
    elif args.dataset == 'cifar100':
        train_data = CIFAR100(root=args.data,
                              train=True,
                              gold=False,
                              gold_fraction=0.0,
                              corruption_prob=args.corruption_prob,
                              corruption_type=args.corruption_type,
                              transform=train_transform,
                              download=True,
                              seed=args.seed)
        gold_train_data = CIFAR100(root=args.data,
                                   train=True,
                                   gold=True,
                                   gold_fraction=1.0,
                                   corruption_prob=args.corruption_prob,
                                   corruption_type=args.corruption_type,
                                   transform=train_transform,
                                   download=True,
                                   seed=args.seed)
    else:
        assert False, "Invalid dataset '{}' given. Must be in {{'cifar10', 'cifar100'}}".format(
            args.dataset)
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    clean_train_queue = torch.utils.data.DataLoader(
        gold_train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=0)
    noisy_train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=0)

    clean_valid_queue = torch.utils.data.DataLoader(
        gold_train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=0)
    noisy_valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=0)

    clean_train_list, clean_valid_list, noisy_train_list, noisy_valid_list = [], [], [], []
    for dst_list, queue in [
        (clean_train_list, clean_train_queue),
        (clean_valid_list, clean_valid_queue),
        (noisy_train_list, noisy_train_queue),
        (noisy_valid_list, noisy_valid_queue),
    ]:
        # Variable/volatile were removed in PyTorch 0.4, and `async` is a
        # reserved word since Python 3.7; use no_grad() and non_blocking=True.
        with torch.no_grad():
            for input, target in queue:
                input = input.cuda()
                target = target.cuda(non_blocking=True)
                dst_list.append((input, target))

    for epoch in range(args.epochs):
        logging.info('Epoch %d, random architecture with fix weights', epoch)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        logging.info(F.softmax(model.alphas_normal, dim=-1))
        logging.info(F.softmax(model.alphas_reduce, dim=-1))

        # training
        clean_train_acc, clean_train_obj = infer(clean_train_list,
                                                 model,
                                                 criterion,
                                                 kind='clean_train')
        logging.info('clean_train_acc %f, clean_train_loss %f',
                     clean_train_acc, clean_train_obj)

        noisy_train_acc, noisy_train_obj = infer(noisy_train_list,
                                                 model,
                                                 criterion,
                                                 kind='noisy_train')
        logging.info('noisy_train_acc %f, noisy_train_loss %f',
                     noisy_train_acc, noisy_train_obj)

        # validation
        clean_valid_acc, clean_valid_obj = infer(clean_valid_list,
                                                 model,
                                                 criterion,
                                                 kind='clean_valid')
        logging.info('clean_valid_acc %f, clean_valid_loss %f',
                     clean_valid_acc, clean_valid_obj)

        noisy_valid_acc, noisy_valid_obj = infer(noisy_valid_list,
                                                 model,
                                                 criterion,
                                                 kind='noisy_valid')
        logging.info('noisy_valid_acc %f, noisy_valid_loss %f',
                     noisy_valid_acc, noisy_valid_obj)

        utils.save(model, os.path.join(args.save, 'weights.pt'))

        # Randomly change the alphas
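        # k = number of edges in a cell: node i has 2 + i incoming edges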
        k = sum(1 for i in range(model._steps) for n in range(2 + i))
        num_ops = len(PRIMITIVES)
        model.alphas_normal.data.copy_(torch.randn(k, num_ops))
        model.alphas_reduce.data.copy_(torch.randn(k, num_ops))
Code Example #3
File: train.py  Project: YIWEI-CHEN/darts
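Trains a fixed architecture (a named genotype or a ResNet baseline) on label-corrupted CIFAR-10/100, validating on a clean or noisy held-out split and testing on the official test set.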
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.data_seed)  # cutout and load_corrupted_data use np.random
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.arch == 'resnet':
        model = ResNet18(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet50':
        model = ResNet50(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet34':
        model = ResNet34(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    else:
        genotype = getattr(genotypes, args.arch)  # safer than eval("genotypes.%s" % args.arch)
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        args.auxiliary, genotype)
        model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, test_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        noisy_train_data = CIFAR10(root=args.data,
                                   train=True,
                                   gold=False,
                                   gold_fraction=0.0,
                                   corruption_prob=args.corruption_prob,
                                   corruption_type=args.corruption_type,
                                   transform=train_transform,
                                   download=True,
                                   seed=args.data_seed)
        gold_train_data = CIFAR10(root=args.data,
                                  train=True,
                                  gold=True,
                                  gold_fraction=1.0,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform,
                                  download=True,
                                  seed=args.data_seed)
        test_data = dset.CIFAR10(root=args.data,
                                 train=False,
                                 download=True,
                                 transform=test_transform)
    elif args.dataset == 'cifar100':
        noisy_train_data = CIFAR100(root=args.data,
                                    train=True,
                                    gold=False,
                                    gold_fraction=0.0,
                                    corruption_prob=args.corruption_prob,
                                    corruption_type=args.corruption_type,
                                    transform=train_transform,
                                    download=True,
                                    seed=args.data_seed)
        gold_train_data = CIFAR100(root=args.data,
                                   train=True,
                                   gold=True,
                                   gold_fraction=1.0,
                                   corruption_prob=args.corruption_prob,
                                   corruption_type=args.corruption_type,
                                   transform=train_transform,
                                   download=True,
                                   seed=args.data_seed)
        test_data = dset.CIFAR100(root=args.data,
                                  train=False,
                                  download=True,
                                  transform=test_transform)
    else:
        assert False, "Invalid dataset '{}' given. Must be in {{'cifar10', 'cifar100'}}".format(
            args.dataset)

    num_train = len(gold_train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    if args.gold_fraction == 1.0:
        train_data = gold_train_data
    else:
        train_data = noisy_train_data
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=0)

    if args.clean_valid:
        valid_data = gold_train_data
    else:
        valid_data = noisy_train_data

    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=0)

    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    elif args.loss_func == 'forward_gold':
        corruption_matrix = train_data.corruption_matrix
        criterion = utils.ForwardGoldLoss(corruption_matrix=corruption_matrix)
    else:
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll', 'forward_gold'}}".format(
            args.loss_func)

    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
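        # drop-path probability ramps linearly from 0 to args.drop_path_prob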
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer_valid(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        test_acc, test_obj = infer(test_queue, model, criterion)
        logging.info('test_acc %f', test_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
Code Example #4
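Like Example #3, but clean and noisy losses are tracked separately during training; with args.random_weight the model is not trained and losses are instead measured under randomly assigned weights.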
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)
    cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(args.seed)

    if args.arch == 'resnet':
        model = ResNet18(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet50':
        model = ResNet50(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet34':
        model = ResNet34(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    else:
        genotype = getattr(genotypes, args.arch)
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        args.auxiliary, genotype)
        model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, test_transform = utils._data_transforms_cifar10(args)

    train_data = CIFAR10(root=args.data,
                         train=True,
                         gold=False,
                         gold_fraction=0.0,
                         corruption_prob=args.corruption_prob,
                         corruption_type=args.corruption_type,
                         transform=train_transform,
                         download=True,
                         seed=args.seed)
    gold_train_data = CIFAR10(root=args.data,
                              train=True,
                              gold=True,
                              gold_fraction=1.0,
                              corruption_prob=args.corruption_prob,
                              corruption_type=args.corruption_type,
                              transform=train_transform,
                              download=True,
                              seed=args.seed)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    clean_train_queue = torch.utils.data.DataLoader(
        gold_train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)

    noisy_train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)

    # clean_train_list = []
    # for input, target in clean_train_queue:
    #   input = Variable(input).cuda()
    #   target = Variable(target).cuda(async=True)
    #   clean_train_list.append((input, target))
    #
    # noisy_train_list = []
    # for input, target in noisy_train_queue:
    #   input = Variable(input).cuda()
    #   target = Variable(target).cuda(async=True)
    #   noisy_train_list.append((input, target))

    clean_valid_queue = torch.utils.data.DataLoader(
        gold_train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=2)

    noisy_valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    elif args.loss_func == 'forward_gold':
        corruption_matrix = train_data.corruption_matrix
        criterion = utils.ForwardGoldLoss(corruption_matrix=corruption_matrix)
    else:
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll', 'forward_gold'}}".format(
            args.loss_func)

    global gain
    for epoch in range(args.epochs):
        if args.random_weight:
            logging.info('Epoch %d, Randomly assign weights', epoch)
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
            clean_obj, noisy_obj = infer_random_weight(clean_train_queue,
                                                       noisy_train_queue,
                                                       model, criterion)
            logging.info('clean loss %f, noisy loss %f', clean_obj, noisy_obj)
            gain = np.random.randint(1, args.grad_clip, size=1)[0]
        else:
            scheduler.step()
            logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

            train_acc, train_obj, another_obj = train(clean_train_queue,
                                                      noisy_train_queue, model,
                                                      criterion, optimizer)
            if args.clean_train:
                logging.info('train_acc %f, clean_loss %f, noisy_loss %f',
                             train_acc, train_obj, another_obj)
            else:
                logging.info('train_acc %f, clean_loss %f, noisy_loss %f',
                             train_acc, another_obj, train_obj)

            utils.save(model, os.path.join(args.save, 'weights.pt'))

        clean_valid_acc, clean_valid_obj = infer_valid(clean_valid_queue,
                                                       model, criterion)
        logging.info('clean_valid_acc %f, clean_valid_loss %f',
                     clean_valid_acc, clean_valid_obj)

        noisy_valid_acc, noisy_valid_obj = infer_valid(noisy_valid_queue,
                                                       model, criterion)
        logging.info('noisy_valid_acc %f, noisy_valid_loss %f',
                     noisy_valid_acc, noisy_valid_obj)
Code Example #5
File: train.py  Project: qingquansong/darts
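Trains a named genotype on (optionally corrupted) CIFAR-10/100; the official test set serves as the validation queue.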
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    genotype = getattr(genotypes, args.arch)
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                    args.auxiliary, genotype)
    model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    else:
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
            args.loss_func)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # valid_data is the test data here
    if args.dataset == 'cifar10':
        if args.gold_fraction == 0:
            train_data = CIFAR10(root=args.data,
                                 train=True,
                                 gold=False,
                                 gold_fraction=args.gold_fraction,
                                 corruption_prob=args.corruption_prob,
                                 corruption_type=args.corruption_type,
                                 transform=train_transform,
                                 download=True,
                                 seed=args.seed)
        else:
            train_data = CIFAR10(root=args.data,
                                 train=True,
                                 gold=True,
                                 gold_fraction=args.gold_fraction,
                                 corruption_prob=args.corruption_prob,
                                 corruption_type=args.corruption_type,
                                 transform=train_transform,
                                 download=True,
                                 seed=args.seed)
        valid_data = dset.CIFAR10(root=args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

    elif args.dataset == 'cifar100':
        if args.gold_fraction == 0:
            train_data = CIFAR100(root=args.data,
                                  train=True,
                                  gold=False,
                                  gold_fraction=args.gold_fraction,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform,
                                  download=True,
                                  seed=args.seed)
        else:
            train_data = CIFAR100(root=args.data,
                                  train=True,
                                  gold=True,
                                  gold_fraction=args.gold_fraction,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform,
                                  download=True,
                                  seed=args.seed)
        valid_data = dset.CIFAR100(root=args.data,
                                   train=False,
                                   download=True,
                                   transform=valid_transform)
    else:
        assert False, "Invalid dataset '{}' given. Must be in {{'cifar10', 'cifar100'}}".format(
            args.dataset)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=2)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
Code Example #6
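Architecture search with a wall-clock time limit and optional multi-GPU execution through a DataParallel wrapper; available GPU memory is queried from nvidia-smi at startup.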
def main():
    np.random.seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)

    # ================================================
    for gpu_id in device_ids:  # gpu_id instead of `id`, which shadows the builtin
        total, used = os.popen(
            'nvidia-smi --query-gpu=memory.total,memory.used '
            '--format=csv,nounits,noheader').read().split('\n')[gpu_id].split(',')
        print('GPU ({}) mem:'.format(gpu_id), total, 'used:', used)


    # try:
    #     block_mem = 0.85 * (total - used)
    #     print(block_mem)
    #     x = torch.empty((256, 1024, int(block_mem))).cuda()
    #     del x
    # except RuntimeError as err:
    #     print(err)
    #     block_mem = 0.8 * (total - used)
    #     print(block_mem)
    #     x = torch.empty((256, 1024, int(block_mem))).cuda()
    #     del x
    #
    #
    # print('reuse mem now ...')
    # ================================================

    args.unrolled = True


    logging.info('GPU device = %s' % args.gpu)
    logging.info("args = %s", args)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        if args.gold_fraction == 0:
            train_data = CIFAR10(
                root=args.data, train=True, gold=False, gold_fraction=args.gold_fraction,
                corruption_prob=args.corruption_prob, corruption_type=args.corruption_type,
                transform=train_transform, download=True, seed=args.seed)
            if args.clean_valid:
                gold_train_data = CIFAR10(
                    root=args.data, train=True, gold=True, gold_fraction=1.0,
                    corruption_prob=args.corruption_prob, corruption_type=args.corruption_type,
                    transform=train_transform, download=True, seed=args.seed)
        else:
            train_data = CIFAR10(
                root=args.data, train=True, gold=True, gold_fraction=args.gold_fraction,
                corruption_prob=args.corruption_prob, corruption_type=args.corruption_type,
                transform=train_transform, download=True, seed=args.seed)
        num_classes = 10

    elif args.dataset == 'cifar100':
        if args.gold_fraction == 0:
            train_data = CIFAR100(
                root=args.data, train=True, gold=False, gold_fraction=args.gold_fraction,
                corruption_prob=args.corruption_prob, corruption_type=args.corruption_type,
                transform=train_transform, download=True, seed=args.seed)
            if args.clean_valid:
                gold_train_data = CIFAR100(
                    root=args.data, train=True, gold=True, gold_fraction=1.0,
                    corruption_prob=args.corruption_prob, corruption_type=args.corruption_type,
                    transform=train_transform, download=True, seed=args.seed)
        else:
            train_data = CIFAR100(
                root=args.data, train=True, gold=True, gold_fraction=args.gold_fraction,
                corruption_prob=args.corruption_prob, corruption_type=args.corruption_type,
                transform=train_transform, download=True, seed=args.seed)
        num_classes = 100
    else:
        assert False, "Invalid dataset '{}' given. Must be in {{'cifar10', 'cifar100'}}".format(args.dataset)

    # Split data to train and validation
    num_train = len(train_data)  # 50000
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))  # 45000

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batchsz,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)

    if args.clean_valid:
        valid_queue = torch.utils.data.DataLoader(
            gold_train_data, batch_size=batchsz,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
            pin_memory=True, num_workers=2)
    else:
        valid_queue = torch.utils.data.DataLoader(
            train_data, batch_size=batchsz,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
            pin_memory=True, num_workers=2)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().cuda()
    else:
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(args.loss_func)
    model = Network(args.init_ch, num_classes, args.layers, criterion)
    if len(device_ids) > 1:
        model = MyDataParallel(model).cuda()
    else:
        model.cuda()
    # model = para_model.module.cuda()

    logging.info("Total param size = %f MB", utils.count_parameters_in_MB(model))

    # this is the optimizer to optimize
    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, float(args.epochs), eta_min=args.lr_min)

    arch = Arch(model, args)

    global start
    start = time.time()
    for epoch in range(args.epochs):
        current_time = time.time()
        if current_time - start >= args.time_limit:
            break
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('\nEpoch: %d lr: %e', epoch, lr)

        genotype = model.genotype()
        logging.info('Genotype: %s', genotype)

        # print(F.softmax(model.alphas_normal, dim=-1))
        # print(F.softmax(model.alphas_reduce, dim=-1))

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model, arch, criterion, optimizer, lr)
        logging.info('train acc: %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid acc: %f', valid_acc)

        utils.save(model, os.path.join(args.exp_path, 'search_epoch1.pt'))