Example #1
# Imports assumed by this snippet; `args`, `utils` and NASNetworkCIFAR are
# defined elsewhere in the surrounding training script.
import logging

import torch
import torch.nn as nn
import torchvision.datasets as dset


def build_cifar100(model_state_dict, optimizer_state_dict, **kwargs):
    epoch = kwargs.pop('epoch')

    # The CIFAR-10 normalization / cutout transforms are reused for CIFAR-100 here.
    train_transform, valid_transform = utils._data_transforms_cifar10(
        args.cutout_size)
    train_data = dset.CIFAR100(root=args.data,
                               train=True,
                               download=True,
                               transform=train_transform)
    valid_data = dset.CIFAR100(root=args.data,
                               train=False,
                               download=True,
                               transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=16)
    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.eval_batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=16)

    model = NASNetworkCIFAR(args, 100, args.layers, args.nodes, args.channels,
                            args.keep_prob, args.drop_path_keep_prob,
                            args.use_aux_head, args.steps, args.arch)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    logging.info("multi adds = %fM", model.multi_adds / 1000000)
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    if torch.cuda.device_count() > 1:
        logging.info("Using %d GPUs", torch.cuda.device_count())
        model = nn.DataParallel(model)
    model = model.cuda()

    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr_max,
        momentum=0.9,
        weight_decay=args.l2_reg,
    )

    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    # last_epoch=epoch resumes the cosine schedule when restarting from a checkpoint.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.lr_min, last_epoch=epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
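
The tuple returned above is meant to be consumed by an outer training loop. The sketch below only illustrates that wiring, assuming `args` and `build_cifar100` are in scope; the 0.4 auxiliary-loss weight mirrors the one used in Example #2 and is not part of this snippet.

# Hypothetical driver loop for the objects returned by build_cifar100().
train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = \
    build_cifar100(None, None, epoch=-1)

global_step = 0
for epoch in range(args.epochs):
    model.train()
    for inputs, targets in train_queue:
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        logits, aux_logits = model(inputs, global_step)  # model returns (logits, aux_logits)
        loss = train_criterion(logits, targets)
        if aux_logits is not None:  # auxiliary head; 0.4 weight assumed, as in Example #2
            loss = loss + 0.4 * train_criterion(aux_logits, targets)
        loss.backward()
        optimizer.step()
        global_step += 1
    scheduler.step()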
Example #2
# Relies on the same module-level imports, `args` and `utils` as Example #1.
def train_and_evaluate_top_on_cifar100(archs, train_queue, valid_queue):
    res = []
    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    for i, arch in enumerate(archs):
        objs.reset()
        top1.reset()
        top5.reset()
        logging.info('Train and evaluate arch {}'.format(i + 1))
        # keep_prob=0.6, drop_path_keep_prob=0.8 and use_aux_head=True are hard-coded here.
        model = NASNetworkCIFAR(args, 100, args.child_layers, args.child_nodes,
                                args.child_channels, 0.6, 0.8, True,
                                args.steps, arch)
        model = model.cuda()
        model.train()
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.child_lr_max,
            momentum=0.9,
            weight_decay=args.child_l2_reg,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 10, args.child_lr_min)
        global_step = 0
        for e in range(10):
            # Pre-1.1.0 PyTorch convention: the scheduler is stepped at the start of each epoch.
            scheduler.step()
            for step, (input, target) in enumerate(train_queue):
                input = input.cuda().requires_grad_()
                target = target.cuda()

                optimizer.zero_grad()
                # forward pass for the current candidate arch;
                # global_step is passed so the model can anneal drop-path internally
                logits, aux_logits = model(input, global_step)
                global_step += 1
                loss = train_criterion(logits, target)
                if aux_logits is not None:
                    aux_loss = train_criterion(aux_logits, target)
                    loss += 0.4 * aux_loss
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.child_grad_bound)
                optimizer.step()

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = input.size(0)
                objs.update(loss.data, n)
                top1.update(prec1.data, n)
                top5.update(prec5.data, n)

                if (step + 1) % 100 == 0:
                    logging.info('Train %3d %03d loss %e top1 %f top5 %f',
                                 e + 1, step + 1, objs.avg, top1.avg, top5.avg)
        objs.reset()
        top1.reset()
        top5.reset()
        with torch.no_grad():
            model.eval()
            for step, (input, target) in enumerate(valid_queue):
                input = input.cuda()
                target = target.cuda()

                logits, _ = model(input)
                loss = eval_criterion(logits, target)

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = input.size(0)
                objs.update(loss.data, n)
                top1.update(prec1.data, n)
                top5.update(prec5.data, n)

                if (step + 1) % 100 == 0:
                    logging.info('valid %03d %e %f %f', step + 1, objs.avg,
                                 top1.avg, top5.avg)
        res.append(top1.avg)
    return res
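
The returned list holds one validation top-1 accuracy per candidate architecture, so a caller would typically use it to rank the candidates. A minimal sketch, assuming `archs`, `train_queue` and `valid_queue` have already been built:

res = train_and_evaluate_top_on_cifar100(archs, train_queue, valid_queue)
best_index = max(range(len(res)), key=lambda i: float(res[i]))  # highest top-1 accuracy
logging.info('best arch index %d with top1 %f', best_index, res[best_index])
best_arch = archs[best_index]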