Example #1
    def run(self):
        self.logger.info('args = %s', self.args)
        run_start = time.time()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            self.scheduler.step()
            self.lr = self.scheduler.get_lr()[0]
            self.logger.info('epoch %d / %d lr %e', epoch, self.args.epochs,
                             self.lr)

            # construct genotype and update topology graph
            genotype = self.model.genotype()

            self.logger.info('genotype = %s', genotype)

            print('alphas normal: \n',
                  F.softmax(self.model.alphas_normal, dim=-1))
            print('alphas reduce: \n',
                  F.softmax(self.model.alphas_reduce, dim=-1))

            # train and search the model
            train_acc, train_obj = self.train()

            # valid the model
            valid_acc, valid_obj = self.infer()
            self.logger.info('valid_acc %f', valid_acc)

            # save checkpoint
            dutils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'use_sparse': self.args.use_sparse,
                    'state_dict': self.model.state_dict(),
                    'dur_time': self.dur_time + time.time() - run_start,
                    'arch_parameter': self.model._arch_parameters,
                    'alphas_normal': self.model.alphas_normal,
                    'alphas_reduce': self.model.alphas_reduce,
                    'arch_optimizer': self.architect.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict(),
                    'optimizer': self.optimizer.state_dict()
                },
                is_best=False,
                save=self.args.save)
            self.logger.info(
                'save checkpoint (epoch %d) in %s  dur_time: %s', epoch,
                self.args.save,
                dutils.calc_time(self.dur_time + time.time() - run_start))
        with open(self.args.save + "/genotype.txt", "w") as f:
            f.write(str(genotype))
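
All three examples on this page delegate persistence to dutils.save_checkpoint, whose implementation is not shown. The sketch below is a minimal, assumed version in the style of the PyTorch ImageNet reference script: serialize the full training state to the save directory, and keep a separate copy of the best checkpoint. The file names and the save/filename parameters are assumptions, not the actual dutils API.

import os
import shutil
import torch

def save_checkpoint(state, is_best, save='.', filename='checkpoint.pth.tar'):
    # Hypothetical sketch: serialize the full training state to disk.
    path = os.path.join(save, filename)
    torch.save(state, path)
    # Keep a separate copy whenever this epoch is the best so far.
    if is_best:
        shutil.copyfile(path, os.path.join(save, 'model_best.pth.tar'))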
Example #2
    def run(self):
        self.logger.info('args = %s', self.args)
        run_start = time.time()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            self.scheduler.step()
            self.logger.info('epoch %d / %d  lr %e', epoch, self.args.epochs,
                             self.scheduler.get_lr()[0])

            if self.args.no_dropout:
                self.model._drop_path_prob = 0
            else:
                self.model._drop_path_prob = self.args.drop_path_prob * epoch / self.args.epochs
                self.logger.info('drop_path_prob %e',
                                 self.model._drop_path_prob)

            train_acc, train_obj = self.train()
            self.logger.info('train loss %e, train acc %f', train_obj,
                             train_acc)

            valid_acc_top1, valid_acc_top5, valid_obj = self.infer()
            self.logger.info(
                'valid loss %e, top1 valid acc %f top5 valid acc %f',
                valid_obj, valid_acc_top1, valid_acc_top5)
            self.logger.info('best valid acc %f', self.best_acc_top1)

            is_best = False
            if valid_acc_top1 > self.best_acc_top1:
                self.best_acc_top1 = valid_acc_top1
                is_best = True

            dutils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'dur_time': self.dur_time + time.time() - run_start,
                    'state_dict': self.model.state_dict(),
                    'drop_path_prob': self.args.drop_path_prob,
                    'best_acc_top1': self.best_acc_top1,
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict()
                }, is_best, self.args.save)
        self.logger.info(
            'train epochs %d, best_acc_top1 %f, dur_time %s',
            self.args.epochs, self.best_acc_top1,
            dutils.calc_time(self.dur_time + time.time() - run_start))
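
Example #2 (like Example #1) reports wall-clock time through dutils.calc_time(self.dur_time + time.time() - run_start), which evidently renders an elapsed duration in seconds as a readable string. A minimal sketch of such a formatter, assuming a days/hours/minutes/seconds layout (the exact output format is a guess):

def calc_time(seconds):
    # Hypothetical sketch: format elapsed seconds as 'Dd Hh Mm Ss'.
    days, rem = divmod(int(seconds), 24 * 3600)
    hours, rem = divmod(rem, 3600)
    minutes, secs = divmod(rem, 60)
    return '{}d {}h {}m {}s'.format(days, hours, minutes, secs)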
Example #3
def main():
    global args, best_prec1
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.dist_rank)

    # create model
    if args.pretrained:
        print('=> using pre-trained model {}'.format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print('=> creating model {}'.format(args.arch))
        model = models.__dict__[args.arch]()

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        dutils.adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        dutils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best)
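
Example #3 adjusts the learning rate once per epoch via dutils.adjust_learning_rate(optimizer, epoch, args.lr). The schedule itself is not shown here; a plausible sketch, assuming the classic ImageNet step decay used by the PyTorch reference script (divide the initial rate by 10 every 30 epochs):

def adjust_learning_rate(optimizer, epoch, lr):
    # Hypothetical sketch: step decay, 10x reduction every 30 epochs.
    new_lr = lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr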