Example #1
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Set random seeds.
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('args = %s', args)

    # Get data loaders.
    train_queue, valid_queue, num_classes = datasets.get_loaders(
        args, 'search')

    # Set up the network and criterion.
    model = Network(num_classes=num_classes,
                    layers=args.layers,
                    dataset=args.dataset)
    model = model.cuda()
    alpha = Alpha(num_normal=model.layers - len(model.channel_change_layers),
                  num_reduce=len(model.channel_change_layers),
                  num_op_normal=len(BLOCK_PRIMITIVES),
                  num_op_reduce=len(REDUCTION_PRIMITIVES),
                  gsm_soften_eps=args.gsm_soften_eps,
                  gsm_temperature=args.gumbel_soft_temp,
                  gsm_type=args.gsm_type,
                  same_alpha_minibatch=args.same_alpha_minibatch)
    alpha = alpha.cuda()
    model = DDP(model, delay_allreduce=True)
    alpha = DDP(alpha, delay_allreduce=True)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    logging.info('param size = %fM ', utils.count_parameters_in_M(model))
    writer.add_scalar('temperature', args.gumbel_soft_temp)

    # Get weight params, and arch params.
    weight_params = list(model.parameters())
    arch_params = list(alpha.parameters())

    logging.info('#Weight params: %d, #Arch params: %d' %
                 (len(weight_params), len(arch_params)))

    # Initial weight pretraining.
    def run_train_init():
        logging.info('running init epochs.')
        opt = torch.optim.Adam(weight_params,
                               args.learning_rate,
                               weight_decay=args.weight_decay)
        for e in range(args.init_epochs):
            # Shuffle the sampler.
            train_queue.sampler.set_epoch(e + args.seed)
            train_acc, train_obj = train_init(train_queue, model, alpha,
                                              criterion, opt, weight_params)
            logging.info('train_init_acc %f', train_acc)
            valid_acc, valid_obj = infer(valid_queue, model, alpha, criterion)
            logging.info('valid_init_acc %f', valid_acc)
            memory_cached, memory_alloc = utils.get_memory_usage(device=0)
            logging.info('memory_cached %0.3f memory_alloc %0.3f' %
                         (memory_cached, memory_alloc))
            if args.local_rank == 0:
                # If you are performing many searches, you can store the pretrained model for reuse.
                torch.save(model.module.state_dict(), args.pretrained_model)

    if args.init_epochs:
        # If you are performing many searches, you can store and reuse the pretrained model.
        # Note that the `and False` below disables loading, so the init epochs always run;
        # drop it to reuse a previously saved pretrained model.
        if os.path.isfile(args.pretrained_model) and False:
            logging.info('loading pretrained model.')
            # Load to CPU first to avoid loading all params onto GPU 0.
            param = torch.load(args.pretrained_model, map_location='cpu')
            model.module.load_state_dict(param, strict=False)
            model.to(torch.device('cuda'))
        else:
            run_train_init()

    # Set up network weights optimizer.
    optimizer = torch.optim.Adam(weight_params,
                                 args.learning_rate,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    # Wrap model in UNAS
    nas = UNAS(model, alpha, args, writer, logging)

    global_step = 0
    for epoch in range(args.epochs):
        # Shuffle the sampler, update lrs.
        train_queue.sampler.set_epoch(epoch + args.seed)
        scheduler.step()
        nas.arch_scheduler.step()

        # Logging.
        if args.local_rank == 0:
            memory_cached, memory_alloc = utils.get_memory_usage(device=0)
            writer.add_scalar('memory/cached', memory_cached, global_step)
            writer.add_scalar('memory/alloc', memory_alloc, global_step)
            logging.info('memory_cached %0.3f memory_alloc %0.3f' %
                         (memory_cached, memory_alloc))
            logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
            writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)
            writer.add_scalar('train/arc_lr',
                              nas.arch_scheduler.get_lr()[0], global_step)

            prob = F.softmax(alpha.module.alphas_normal, dim=-1)
            logging.info('alphas_normal:')
            logging.info(prob)
            fig = alpha.module.plot_alphas(BLOCK_PRIMITIVES, is_normal=True)
            writer.add_figure('weights/disp_normal', fig, global_step)

            prob = F.softmax(alpha.module.alphas_reduce, dim=-1)
            logging.info('alphas_reduce:')
            logging.info(prob)
            fig = alpha.module.plot_alphas(REDUCTION_PRIMITIVES,
                                           is_normal=False)
            writer.add_figure('weights/disp_reduce', fig, global_step)

        # Training.
        train_acc, train_obj, global_step = train(train_queue, valid_queue,
                                                  model, alpha, nas, criterion,
                                                  optimizer, global_step,
                                                  weight_params, args.seed)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('train/acc', train_acc, global_step)

        # Validation.
        valid_queue.sampler.set_epoch(0)
        valid_acc, valid_obj = infer(valid_queue, model, alpha, criterion)
        logging.info('valid_acc %f', valid_acc)
        writer.add_scalar('val/acc', valid_acc, global_step)
        writer.add_scalar('val/loss', valid_obj, global_step)

        if args.local_rank == 0:
            logging.info('Saving the model and genotype.')
            utils.save(model, os.path.join(args.save, 'weights.pt'))
            torch.save(
                alpha.module.genotype(BLOCK_PRIMITIVES, REDUCTION_PRIMITIVES),
                os.path.join(args.save, 'genotype.pt'))

    writer.flush()
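
Example #1 calls two small helpers from `utils` that are not shown above. The sketch below is only an assumption about their behavior, inferred from how they are used: `count_parameters_in_M` returning the trainable parameter count in millions, and `get_memory_usage` returning cached and allocated GPU memory in GB.

import numpy as np
import torch


def count_parameters_in_M(model):
    # Assumed behavior: trainable parameter count in millions, skipping
    # auxiliary-head weights as is common in DARTS-style code bases.
    return sum(np.prod(v.size()) for name, v in model.named_parameters()
               if 'auxiliary' not in name) / 1e6


def get_memory_usage(device=0):
    # Assumed behavior: (cached, allocated) GPU memory on `device`, in GB.
    cached = torch.cuda.memory_reserved(device) / 1e9
    alloc = torch.cuda.memory_allocated(device) / 1e9
    return cached, alloc
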
Example #2
def main():
    # Scale learning rate based on global batch size.
    if not args.no_scale_lr:
        scale = float(args.batch_size * args.world_size) / 128.0
        args.learning_rate = scale * args.learning_rate

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info('args = %s', args)

    # Get data loaders.
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.4,
                               saturation=0.4,
                               hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    if 'lmdb' in args.data:
        train_data = imagenet_lmdb_dataset(traindir, transform=train_transform)
        valid_data = imagenet_lmdb_dataset(validdir, transform=val_transform)
    else:
        train_data = dset.ImageFolder(traindir, transform=train_transform)
        valid_data = dset.ImageFolder(validdir, transform=val_transform)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=8,
                                              sampler=train_sampler)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=8)

    # Set up the network.
    if os.path.isfile(args.genotype):
        logging.info('Loading genotype from: %s' % args.genotype)
        genotype = torch.load(args.genotype, map_location='cpu')
    else:
        logging.info('Loading genotype: %s' % args.genotype)
        genotype = eval('genotypes.%s' % args.genotype)
    if not isinstance(genotype, list):
        genotype = [genotype]

    # If num channels not provided, find the max under 600M MAdds.
    if args.init_channels < 0:
        if args.local_rank == 0:
            flops, num_params, init_channels = find_max_channels(
                genotype, args.layers, args.max_M_flops * 1e6)
            logging.info('Num flops = %.2fM', flops / 1e6)
            logging.info('Num params = %.2fM', num_params / 1e6)
        else:
            init_channels = 0
        # All-reduce with divisor 1 is a plain sum, broadcasting the value computed on rank 0.
        init_channels = torch.Tensor([init_channels]).cuda()
        init_channels = utils.reduce_tensor(init_channels, 1)
        args.init_channels = int(init_channels.item())
    logging.info('Num channels = %d', args.init_channels)

    # Create model and loss.
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary,
                    genotype)
    model = model.cuda()
    model = DDP(model, delay_allreduce=True)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()
    logging.info('param size = %fM', utils.count_parameters_in_M(model))

    # Set up network weights optimizer.
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.min_learning_rate)

    # Train.
    global_step = 0
    best_acc_top1 = 0
    for epoch in range(args.epochs):
        # Shuffle the sampler, update lrs.
        train_queue.sampler.set_epoch(epoch + args.seed)
        # Change lr.
        if epoch >= args.warmup_epochs:
            scheduler.step()
        model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        # Training.
        train_acc, train_obj, global_step = train(train_queue, model,
                                                  criterion_smooth, optimizer,
                                                  epoch, args.learning_rate,
                                                  args.warmup_epochs,
                                                  global_step)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('train/acc', train_acc, global_step)

        # Validation.
        valid_acc_top1, valid_acc_top5, valid_obj = infer(
            valid_queue, model, criterion)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)
        writer.add_scalar('val/acc_top1', valid_acc_top1, global_step)
        writer.add_scalar('val/acc_top5', valid_acc_top5, global_step)
        writer.add_scalar('val/loss', valid_obj, global_step)

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        if args.local_rank == 0:
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc_top1': best_acc_top1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save)
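
Example #2 relies on `CrossEntropyLabelSmooth` and `utils.reduce_tensor`, neither of which is defined above. Below is a minimal sketch under the usual assumptions: the standard uniform label-smoothing loss, and an apex-style all-reduce that sums across ranks and divides by the given world size (so passing 1, as in the channel-count broadcast, yields a plain sum).

import torch
import torch.nn as nn
import torch.distributed as dist


class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy with uniform label smoothing (standard formulation)."""

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # One-hot encode the targets, then mix with the uniform distribution.
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        return (-targets * log_probs).mean(0).sum()


def reduce_tensor(tensor, world_size):
    # Assumed behavior: sum across ranks, then divide by `world_size`.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt
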
Example #3
def main():
    """Do everything!"""
    # Scale learning rate based on global batch size.
    if args.scale_lr:
        scale = float(args.batch_size * args.world_size) / 64.0
        args.learning_rate = scale * args.learning_rate
        args.arch_learning_rate = scale * args.arch_learning_rate

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d', args.gpu)
    logging.info('args = %s', args)

    # Get data loaders.
    train_queue, valid_queue, num_classes = datasets.get_loaders(args, 'search')

    # Set up the network and criterion.
    model = Network(args.init_channels, num_classes, args.layers,
                    num_cell_types=args.num_cell_types,
                    dataset=args.dataset,
                    steps=args.steps,
                    multiplier=args.multiplier)
    model = model.cuda()
    alpha = Alpha(num_normal=1,
                  num_reduce=1,
                  num_op=len(PRIMITIVES),
                  num_nodes=args.steps,
                  gsm_soften_eps=args.gsm_soften_eps,
                  gsm_temperature=args.gumbel_soft_temp,
                  gsm_type=args.gsm_type,
                  same_alpha_minibatch=args.same_alpha_minibatch)
    alpha = alpha.cuda()
    model = DDP(model, delay_allreduce=True)
    alpha = DDP(alpha, delay_allreduce=True)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    logging.info('param size = %fM ', utils.count_parameters_in_M(model))
    writer.add_scalar('temperature', args.gumbel_soft_temp)

    # Get weight params, and arch params.
    weight_params = list(model.parameters())
    arch_params = list(alpha.parameters())

    logging.info('#Weight params: %d, #Arch params: %d' %
                 (len(weight_params), len(arch_params)))

    # Initial weight pretraining.
    def run_train_init():
        logging.info('running init epochs.')
        opt = torch.optim.Adam(
            weight_params,
            args.learning_rate,
            weight_decay=args.weight_decay)
        for e in range(args.init_epochs):
            # Shuffle the sampler.
            train_queue.sampler.set_epoch(e + args.seed)
            train_acc, train_obj = train_init(
                train_queue, model, alpha, criterion, opt, weight_params)
            logging.info('train_init_acc %f', train_acc)
            valid_acc, valid_obj = infer(valid_queue, model, alpha, criterion)
            logging.info('valid_init_acc %f', valid_acc)
            memory_cached, memory_alloc = utils.get_memory_usage(device=0)
            logging.info('memory_cached %0.3f memory_alloc %0.3f' % (memory_cached, memory_alloc))

    if args.init_epochs:
        run_train_init()

    # Set up network weights optimizer.
    optimizer = torch.optim.Adam(
        weight_params, args.learning_rate, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    # Wrap model in architecture learner.
    nas = UNAS(model, alpha, args, writer, logging)

    global_step = 0
    for epoch in range(args.epochs):
        # Shuffle the sampler, update lrs.
        train_queue.sampler.set_epoch(epoch + args.seed)
        scheduler.step()
        nas.arch_scheduler.step()

        # Logging.
        if args.local_rank == 0:
            memory_cached, memory_alloc = utils.get_memory_usage(device=0)
            writer.add_scalar('memory/cached', memory_cached, global_step)
            writer.add_scalar('memory/alloc', memory_alloc, global_step)
            logging.info('memory_cached %0.3f memory_alloc %0.3f' % (memory_cached, memory_alloc))
            logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
            writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)
            writer.add_scalar('train/arc_lr', nas.arch_scheduler.get_lr()[0], global_step)

            genotypes = alpha.module.genotype()
            logging.info('genotype:')
            for genotype in genotypes:
                logging.info('normal:')
                logging.info(genotype.normal)
                logging.info('reduce:')
                logging.info(genotype.reduce)

            for l in range(len(alpha.module.alphas_normal)):
                fig = alpha.module.plot_alphas(alpha.module.alphas_normal[l])
                writer.add_figure('weights/disp_normal_%d' % l, fig, global_step)
            for l in range(len(alpha.module.alphas_reduce)):
                fig = alpha.module.plot_alphas(alpha.module.alphas_reduce[l])
                writer.add_figure('weights/disp_reduce_%d' % l, fig, global_step)

        # Training.
        train_acc, train_obj, global_step = train(
            train_queue, valid_queue, model, alpha, nas, criterion,
            optimizer, global_step, weight_params, args.seed)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('train/acc', train_acc, global_step)

        # Validation.
        valid_queue.sampler.set_epoch(0)
        valid_acc, valid_obj = infer(valid_queue, model, alpha, criterion)
        logging.info('valid_acc %f', valid_acc)
        writer.add_scalar('val/acc', valid_acc, global_step)
        writer.add_scalar('val/loss', valid_obj, global_step)

        if args.local_rank == 0:
            utils.save(model, os.path.join(args.save, 'weights.pt'))
            torch.save(alpha.module.genotype(), os.path.join(args.save, 'genotype.pt'))

    writer.flush()
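
All three examples configure `Alpha` with `gsm_temperature` (`args.gumbel_soft_temp`), but the sampling code itself is not shown. As a rough, hypothetical illustration of how a Gumbel-softmax relaxation turns architecture logits into (near) one-hot operation weights; `sample_op_weights` is an invented name, not part of the code above.

import torch
import torch.nn.functional as F


def sample_op_weights(logits, temperature, hard=False):
    # Hypothetical helper: draw a relaxed one-hot sample over candidate ops.
    # Lower temperatures push the sample closer to a discrete choice.
    gumbels = -torch.empty_like(logits).exponential_().log()
    weights = F.softmax((logits + gumbels) / temperature, dim=-1)
    if hard:
        # Straight-through: forward pass is one-hot, gradients flow through the soft sample.
        index = weights.argmax(dim=-1, keepdim=True)
        hard_weights = torch.zeros_like(weights).scatter_(-1, index, 1.0)
        weights = hard_weights - weights.detach() + weights
    return weights

PyTorch also provides F.gumbel_softmax(logits, tau=temperature, hard=hard), which implements the same relaxation directly.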