Example #1
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

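    # Fix random seeds for NumPy and PyTorch; cudnn.benchmark autotunes
    # convolution kernels for the fixed input size.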
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

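    # Look up the searched cell by name and build the evaluation network.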
    genotype = eval("genotypes.%s" % args.arch)
    if args.is_cifar100:
        model = Network(args.init_channels, CIFAR100_CLASSES, args.layers,
                        args.auxiliary, genotype)
    else:
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        args.auxiliary, genotype)
    model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.is_cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
        train_data = dset.CIFAR100(root=args.data,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
        valid_data = dset.CIFAR100(root=args.data,
                                   train=False,
                                   download=True,
                                   transform=valid_transform)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
        train_data = dset.CIFAR10(root=args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=2)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

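    # Cosine-annealed learning rate over the whole run; resuming restores
    # model, optimizer, and scheduler state from the checkpoint.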
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
    for epoch in range(start_epoch, args.epochs):
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        scheduler.step()
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),
            'optimizer': optimizer.state_dict()
        })
Example #2
def main():
	if not torch.cuda.is_available():
		logging.info('no gpu device available')
		sys.exit(1)
	
	np.random.seed(args.seed)
	# torch.cuda.set_device(args.gpu)
	device = torch.device("cuda")
	cudnn.benchmark = True
	torch.manual_seed(args.seed)
	cudnn.enabled = True
	torch.cuda.manual_seed(args.seed)
	logging.info('gpu device = %d' % args.gpu)
	logging.info("args = %s", args)
	
	# read data
	train_transform, valid_transform = utils._data_transforms_cifar10(args)  # note: CIFAR-10 transforms are reused for CIFAR-100 below
	if args.dataset == 'cifar10':
		args.data = '/home/work/dataset/cifar'
		train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
		valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
		classes = 10
	elif args.dataset == 'cifar100':
		args.data = '/home/work/dataset/cifar100'
		train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
		valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
		classes = 100
	train_queue = torch.utils.data.DataLoader(
		train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)
	valid_queue = torch.utils.data.DataLoader(
		valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)
	
	# model
	genotype = eval("genotypes.%s" % args.arch)
	model = Network(args.init_channels, classes, args.layers, args.auxiliary, genotype)
	model = model.cuda()
	model.drop_path_prob = args.drop_path_prob
	
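	# Estimate FLOPs and parameter count on a single CIFAR-sized input
	# (profile here is presumably thop.profile).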
	flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),), verbose=False)
	logging.info('flops = %fM', flops / 1e6)
	logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
	
	criterion = nn.CrossEntropyLoss()
	criterion = criterion.cuda()
	optimizer = torch.optim.SGD(
		model.parameters(),
		args.learning_rate,
		momentum=args.momentum,
		weight_decay=args.weight_decay
	)
	
	scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
	best_val_acc = 0.
	
	if args.resume:
		state = torch.load('/home/work/lixudong/code_work/sgas/cnn/full_train_s3_2-20200608/weights.pt', map_location='cpu')
		model.load_state_dict(state)
		model = model.to(device)
		for i in range(args.start_epoch):
			scheduler.step()
		best_val_acc = 97.19  # best validation accuracy of the resumed run
		
	for epoch in range(args.start_epoch, args.epochs):
		scheduler.step()
		logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
		model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
		train_acc, train_obj = train(train_queue, model, criterion, optimizer)
		logging.info('train_acc %f', train_acc)
		
		with torch.no_grad():
			valid_acc, valid_obj = infer(valid_queue, model, criterion)
			if valid_acc > best_val_acc:
				best_val_acc = valid_acc
				utils.save(model, os.path.join(args.save, 'best_weights.pt'))
			# logging.info('valid_acc %f\tbest_val_acc %f', valid_acc, best_val_acc)
			logging.info('val_acc: {:.6}, best_val_acc: \033[31m{:.6}\033[0m'.format(valid_acc, best_val_acc))
		
		state = {
			'epoch': epoch,
			'model_state': model.state_dict(),
			'optimizer': optimizer.state_dict(),
			'best_val_acc': best_val_acc
		}
		torch.save(state, os.path.join(args.save, 'weights.pt.tar'))
Example #3
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    logging.info("unparsed args = %s", unparsed)
    num_gpus = torch.cuda.device_count()

    genotype = eval("genotypes.%s" % args.arch)

    print('---------Genotype---------')
    logging.info(genotype)
    print('--------------------------')
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                    args.auxiliary, genotype)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
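    # Wrap the model in DataParallel so all visible GPUs are used.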
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    start_epoch = 0
    if args.resume:
        MT = torch.load(os.path.join(args.save, 'weight_optimizers.pt'))
        model.load_state_dict(MT['net'])
        optimizer.load_state_dict(MT['optimizer'])
        start_epoch = MT['epoch']

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
        train_data = dset.CIFAR100(root=args.data_dir,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
        valid_data = dset.CIFAR100(root=args.data_dir,
                                   train=False,
                                   download=True,
                                   transform=valid_transform)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
        train_data = dset.CIFAR10(root=args.data_dir,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=args.data_dir,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=args.workers)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))
    best_acc = 0.0
    for epoch in range(start_epoch, args.epochs):

        # set drop_path_prob on both the DataParallel wrapper and the
        # wrapped module so either access pattern works
        model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        start_time = time.time()
        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('Train_acc: %f', train_acc)

        scheduler.step()
        logging.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        if valid_acc > best_acc:
            best_acc = valid_acc
            state = {
                'net': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch
            }
            torch.save(state,
                       os.path.join(args.save, 'best_weight_optimizers.pt'))
        logging.info('Valid_acc: %f, best_acc: %f', valid_acc, best_acc)
        end_time = time.time()
        duration = end_time - start_time
        print('Estimated remaining time: %d h.' % (duration * (args.epochs - epoch) / 3600))
        if epoch % 50 == 0:
            state = {
                'net': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch
            }
            torch.save(state, os.path.join(args.save, 'weight_optimizers.pt'))
Example #4
def main(args):
    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
        if args.use_data_parallel else fluid.CUDAPlace(0)

    with fluid.dygraph.guard(place):
        genotype = eval("genotypes.%s" % args.arch)
        model = Network(
            C=args.init_channels,
            num_classes=args.class_num,
            layers=args.layers,
            auxiliary=args.auxiliary,
            genotype=genotype)

        logger.info("param size = {:.6f}MB".format(
            count_parameters_in_MB(model.parameters())))

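        # Steps per epoch shrink as devices are added; the cosine decay
        # schedule is defined over those per-device steps.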
        device_num = fluid.dygraph.parallel.Env().nranks
        step_per_epoch = int(args.trainset_num /
                             (args.batch_size * device_num))
        learning_rate = fluid.dygraph.CosineDecay(args.learning_rate,
                                                  step_per_epoch, args.epochs)
        clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.grad_clip)
        optimizer = fluid.optimizer.MomentumOptimizer(
            learning_rate,
            momentum=args.momentum,
            regularization=fluid.regularizer.L2Decay(args.weight_decay),
            parameter_list=model.parameters(),
            grad_clip=clip)

        if args.use_data_parallel:
            strategy = fluid.dygraph.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        train_loader = fluid.io.DataLoader.from_generator(
            capacity=64,
            use_double_buffer=True,
            iterable=True,
            return_list=True,
            use_multiprocess=args.use_multiprocess)
        valid_loader = fluid.io.DataLoader.from_generator(
            capacity=64,
            use_double_buffer=True,
            iterable=True,
            return_list=True,
            use_multiprocess=args.use_multiprocess)

        train_reader = reader.train_valid(
            batch_size=args.batch_size,
            is_train=True,
            is_shuffle=True,
            args=args)
        valid_reader = reader.train_valid(
            batch_size=args.batch_size,
            is_train=False,
            is_shuffle=False,
            args=args)
        if args.use_data_parallel:
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)

        train_loader.set_batch_generator(train_reader, places=place)
        valid_loader.set_batch_generator(valid_reader, places=place)

        save_parameters = (not args.use_data_parallel) or (
            args.use_data_parallel and
            fluid.dygraph.parallel.Env().local_rank == 0)
        best_acc = 0
        for epoch in range(args.epochs):
            drop_path_prob = args.drop_path_prob * epoch / args.epochs
            logger.info('Epoch {}, lr {:.6f}'.format(
                epoch, optimizer.current_step_lr()))
            train_top1 = train(model, train_loader, optimizer, epoch,
                               drop_path_prob, args)
            logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))
            valid_top1 = valid(model, valid_loader, epoch, args)
            if valid_top1 > best_acc:
                best_acc = valid_top1
                if save_parameters:
                    fluid.save_dygraph(model.state_dict(),
                                       args.model_save_dir + "/best_model")
            logger.info("Epoch {}, valid_acc {:.6f}, best_valid_acc {:.6f}".
                        format(epoch, valid_top1, best_acc))
Example #5
def run(net,
        init_ch=32,
        layers=20,
        auxiliary=True,
        lr=0.025,
        momentum=0.9,
        wd=3e-4,
        cutout=True,
        cutout_length=16,
        data='../data',
        batch_size=96,
        epochs=600,
        drop_path_prob=0.2,
        auxiliary_weight=0.4):
    save = '/checkpoint/linnanwang/nasnet/' + hashlib.md5(
        json.dumps(net).encode()).hexdigest()
    utils.create_exp_dir(save, scripts_to_save=glob.glob('*.py'))

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    np.random.seed(0)
    torch.cuda.set_device(0)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(0)
    logging.info('gpu device = %d' % 0)
    # logging.info("args = %s", args)

    genotype = net
    model = Network(init_ch, 10, layers, auxiliary, genotype).cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=wd)
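    # Mixed-precision training via NVIDIA Apex AMP; opt_level O3 is pure
    # FP16 (O1 is the more conservative default).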
    model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O3")

    train_transform, valid_transform = utils._data_transforms_cifar10(
        cutout, cutout_length)
    train_data = dset.CIFAR10(root=data,
                              train=True,
                              download=True,
                              transform=train_transform)
    valid_data = dset.CIFAR10(root=data,
                              train=False,
                              download=True,
                              transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=2)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(epochs))

    best_acc = 0.0

    for epoch in range(epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = drop_path_prob * epoch / epochs

        train_acc, train_obj = train(train_queue,
                                     model,
                                     criterion,
                                     optimizer,
                                     auxiliary=auxiliary,
                                     auxiliary_weight=auxiliary_weight)
        logging.info('train_acc: %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)

        if valid_acc > best_acc:
            best_acc = valid_acc
            if epoch >= 50:
                print('this model is the best')
                torch.save(model.state_dict(), os.path.join(save, 'model.pt'))
                print('saved to: model.pt')
        print('current best acc is', best_acc)

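        # Early exit: this run stops after epoch 100 even though the
        # schedule is defined over the full epochs budget.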
        if epoch == 100:
            break


    return best_acc
Example #6
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    model = model.to('cuda')

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to('cuda')

    # do not apply weight decay to BatchNorm parameters
    if args.bn_no_decay:
        logging.info('BN layers are excluded from weight decay')
        bn_params, other_params = utils.split_bn_params(model)
        logging.debug('bn: %s', [p.dtype for p in bn_params])
        logging.debug('other: %s', [p.dtype for p in other_params])
        param_group = [{'params': bn_params, 'weight_decay': 0},
                       {'params': other_params}]
    else:
        param_group = model.parameters()

    optimizer = torch.optim.SGD(
        param_group,
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    logging.info('optimizer: %s', optimizer)

    train_transform, valid_transform = utils.data_transforms_cifar10(args)
    train_data = datasets.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    valid_data = datasets.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.num_workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.num_workers)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

    init_epoch = 0
    best_acc = 0

    if args.recover:
        states = torch.load(args.recover)
        model.load_state_dict(states['state'])
        init_epoch = states['epoch'] + 1
        best_acc = states['best_acc']
        logging.info('checkpoint loaded')
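        # Fast-forward the schedule to the resumed epoch (passing an epoch
        # to step() is deprecated in newer PyTorch releases).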
        scheduler.step(init_epoch)
        logging.info('scheduler is set to epoch %d. learning rate is %s', init_epoch, scheduler.get_lr())

    for epoch in range(init_epoch, args.epochs):
        logging.info('epoch %d lr %s', epoch, scheduler.get_lr())
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %.4f', train_acc)

        with torch.no_grad():
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('valid_acc %f', valid_acc)

        logging.info('epoch %03d overall train_acc=%.4f valid_acc=%.4f', epoch, train_acc, valid_acc)

        scheduler.step()

        # gpu info
        utils.gpu_usage(args.debug)

        if valid_acc > best_acc:
            best_acc = valid_acc

        logging.info('best acc: %.4f', best_acc)

        utils.save_checkpoint(state={'state': model.state_dict(),
                                     'epoch': epoch,
                                     'best_acc': best_acc},
                              is_best=False, save=args.save)
Example #7
def main():

    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)
    cur_epoch = 0

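    # Decode the flat architecture list into a genotype (4 values per node).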
    net = eval(args.arch)
    print(net)
    code = gen_code_from_list(net, node_num=int((len(net) / 4)))
    genotype = translator([code, code], max_node=int((len(net) / 4)))
    print(genotype)

    model_ema = None

    if not continue_train:

        print('train from scratch')
        model = Network(args.init_ch, 10, args.layers, args.auxiliary,
                        genotype).cuda()
        print("model init params values:", flatten_params(model))

        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

        criterion = CutMixCrossEntropyLoss(True).cuda()

        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd)

        if args.model_ema:
            model_ema = ModelEma(
                model,
                decay=args.model_ema_decay,
                device='cpu' if args.model_ema_force_cpu else '')

    else:
        print('continue train from checkpoint')

        model = Network(args.init_ch, 10, args.layers, args.auxiliary,
                        genotype).cuda()

        criterion = CutMixCrossEntropyLoss(True).cuda()

        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd)

        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

        checkpoint = torch.load(args.save + '/model.pt')
        model.load_state_dict(checkpoint['model_state_dict'])
        cur_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if args.model_ema:

            model_ema = ModelEma(
                model,
                decay=args.model_ema_decay,
                device='cpu' if args.model_ema_force_cpu else '',
                resume=args.save + '/model.pt')

    train_transform, valid_transform = utils._auto_data_transforms_cifar10(
        args)

    ds_train = dset.CIFAR10(root=args.data,
                            train=True,
                            download=True,
                            transform=train_transform)

    args.cv = -1  # force the no-validation-split branch; set >= 0 to use a stratified CV fold
    if args.cv >= 0:
        sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
        sss = sss.split(list(range(len(ds_train))), ds_train.targets)
        for _ in range(args.cv + 1):
            train_idx, valid_idx = next(sss)
        ds_valid = Subset(ds_train, valid_idx)
        ds_train = Subset(ds_train, train_idx)
    else:
        ds_valid = Subset(ds_train, [])

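    # Wrap the training set with CutMix augmentation (up to 2 mixes per
    # sample, applied with probability 0.5).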
    train_queue = torch.utils.data.DataLoader(CutMix(ds_train,
                                                     10,
                                                     beta=1.0,
                                                     prob=0.5,
                                                     num_mix=2),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True)

    valid_queue = torch.utils.data.DataLoader(dset.CIFAR10(
        root=args.data, train=False, transform=valid_transform),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    best_acc = 0.0

    if continue_train:
        for i in range(cur_epoch + 1):
            scheduler.step()

    for epoch in range(cur_epoch, args.epochs):
        print('cur_epoch is', epoch)
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        if model_ema is not None:
            model_ema.ema.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer,
                                     epoch, model_ema)
        logging.info('train_acc: %f', train_acc)

        if model_ema is not None and not args.model_ema_force_cpu:
            valid_acc_ema, valid_obj_ema = infer(valid_queue,
                                                 model_ema.ema,
                                                 criterion,
                                                 ema=True)
            logging.info('valid_acc_ema %f', valid_acc_ema)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)

        if valid_acc > best_acc:
            best_acc = valid_acc
            print('this model is the best')
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, os.path.join(args.save, 'top1.pt'))
        print('current best acc is', best_acc)
        logging.info('best_acc: %f', best_acc)

        if model_ema is not None:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'state_dict_ema': get_state_dict(model_ema)
                }, os.path.join(args.save, 'model.pt'))

        else:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, os.path.join(args.save, 'model.pt'))

        print('saved to: model.pt')
Example #8
File: train.py Project: lixingjian/darts
def main():
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  #torch.cuda.set_device(args.gpu)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled = True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  logging.info("args = %s", args)

  genotype = eval("genotypes.%s" % args.arch)
  model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype, args.residual_wei, args.shrink_channel)
  model = model.cuda()

  logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

  criterion = nn.CrossEntropyLoss()
  criterion = criterion.cuda()
  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay
      )

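  # Resume from checkpoint.pth.tar if one exists in the save directory,
  # mapping all tensors onto GPU 0.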
  resume = os.path.join(args.save, 'checkpoint.pth.tar')
  if os.path.exists(resume):
    print("=> loading checkpoint %s" % resume)
    #checkpoint = torch.load(resume)
    checkpoint = torch.load(resume, map_location = lambda storage, loc: storage.cuda(0))
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    # mutating the dict returned by state_dict() is a no-op, so restore
    # the optimizer state with load_state_dict instead
    optimizer.load_state_dict(checkpoint['optimizer'])
    print('=> loaded checkpoint epoch %d' % args.start_epoch)
    if args.start_epoch >= args.epochs:
        print('training finished')
        sys.exit(0)

  train_transform, valid_transform = utils._data_transforms_cifar10(args)
  train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
  valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)

  valid_queue = torch.utils.data.DataLoader(
      valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
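  # Fast-forward the cosine schedule to the resume epoch.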
  for _ in range(args.start_epoch):
    scheduler.step()

  for epoch in range(args.start_epoch, args.epochs):
    scheduler.step()
    logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

    train_acc, train_obj = train(train_queue, model, criterion, optimizer)
    logging.info('train_acc %f', train_acc)

    if epoch == args.epochs - 1:
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

    utils.save_checkpoint({
      'epoch': epoch + 1,
      'state_dict': model.state_dict(),
      'best_acc_top1': train_acc,  # note: stores the latest train accuracy, not a validated best
      'optimizer' : optimizer.state_dict(),
      }, False, args.save)
Example #9
valid_data = dset.CIFAR10(root='/home/xiaoda/data',
                          train=False,
                          download=False,
                          transform=train_transform)  # note: reuses train_transform for validation
writer = SummaryWriter(args.tensorboard_log)
if not os.path.exists(args.results_dir):
    os.mkdir(args.results_dir)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       float(args.epochs))
train_queue = torch.utils.data.DataLoader(train_data,
                                          batch_size=args.batch_size,
                                          shuffle=True)

valid_queue = torch.utils.data.DataLoader(valid_data,
                                          batch_size=args.batch_size,
                                          shuffle=False)
for epoch in range(args.epochs):
    scheduler.step()
    print("Start to train for epoch %d" % (epoch))
    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
    train_acc, train_obj = train(train_queue, model, criterion, optimizer)
    print("Start to validate for epoch %d" % (epoch))
    valid_acc, valid_obj = infer(valid_queue, model, criterion)

    writer.add_scalar('Train/train_acc', train_acc, epoch)
    writer.add_scalar('Train/train_loss', train_obj, epoch)
    writer.add_scalar('Val/valid_acc', valid_acc, epoch)
    writer.add_scalar('Val/valid_loss', valid_obj, epoch)

torch.save(model.state_dict(), args.results_dir + '/weights.pt')