Example #1
# Imports this excerpt needs; `args`, `device`, `CLASSES`, `genotypes`, `utils`,
# `Network`, `CrossEntropyLabelSmooth`, `train`, and `infer` are module-level
# names defined elsewhere in the original script.
import logging
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

def main():
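    # Fix RNG seeds for reproducibility; cudnn.benchmark lets cuDNN pick the
    # fastest convolution algorithms for the fixed input size.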
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("device = %s" % 'cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
    logging.info("args = %s", args)

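    # Look up the searched cell genotype by name and build the full network,
    # optionally with an auxiliary classification head.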
    genotype = getattr(genotypes, args.arch)  # safer equivalent of eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
    if args.parallel:
        model = nn.DataParallel(model).to(device)
    else:
        model = model.to(device)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

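    # Plain cross-entropy for evaluation; a label-smoothed variant for training.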
    criterion = nn.CrossEntropyLoss().to(device)
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth).to(device)

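    # SGD with momentum and weight decay over all network parameters.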
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

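    # Standard ImageNet pipeline: random resized crops with flips and color
    # jitter for training, resize + center crop for validation, both
    # normalized with the usual ImageNet channel statistics.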
    traindir = os.path.join(args.data, "train")
    validdir = os.path.join(args.data, "val")
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_data = dset.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    valid_data = dset.ImageFolder(
        validdir,
        transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

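    # pin_memory speeds up host-to-GPU transfers during CUDA training.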
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=4,
    )

    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=4,
    )

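    # Decay the learning rate by `gamma` every `decay_period` epochs.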
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, args.decay_period, gamma=args.gamma
    )

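    # Main loop: train with the smoothed loss, evaluate with plain
    # cross-entropy, and checkpoint every epoch, flagging new best top-1 results.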
    best_acc_top1 = 0
    for epoch in range(args.epochs):
        logging.info("epoch %d lr %e", epoch, scheduler.get_last_lr()[0])
        # Linearly ramp drop-path regularization over training; reach through
        # the DataParallel wrapper so the attribute lands on the real module.
        module = model.module if args.parallel else model
        module.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer)
        # Step the scheduler after the epoch's optimizer updates (the
        # PyTorch >= 1.1 calling order, rather than at the top of the loop).
        scheduler.step()
        logging.info("train_acc %f", train_acc)

        valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
        logging.info("valid_acc_top1 %f", valid_acc_top1)
        logging.info("valid_acc_top5 %f", valid_acc_top5)

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        utils.save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_acc_top1": best_acc_top1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            args.save,
        )
Example #2
# Imports this excerpt needs; `Genotype`, `MyDataset`, `NetworkCIFAR`,
# `NetworkImageNet`, `CrossEntropyLabelSmooth`, `utils`, `device`, and
# `use_DataParallel` are project-level names defined elsewhere.
import logging
import os
import sys
from argparse import Namespace
from typing import Tuple

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from tqdm import tqdm

class Trainer:
    def __init__(self,
                 args: Namespace,
                 genotype: Genotype,
                 my_dataset: MyDataset,
                 choose_cell=False):

        self.__args = args
        self.__dataset = my_dataset
        self.__previous_epochs = 0

        if args.seed is None:
            raise ValueError('args.seed must be specified.')
        if args.epochs is None:
            raise ValueError('args.epochs must be specified.')
        if not (args.arch or args.arch_path):
            raise ValueError('specify an architecture via args.arch or args.arch_path.')

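        # Log to both stdout and a file under args.save, then seed the RNGs.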
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        np.random.seed(args.seed)
        cudnn.benchmark = True
        cudnn.enabled = True
        torch.manual_seed(args.seed)

        logging.info(f'gpu device = {args.gpu}')
        logging.info(f'args = {args}')

        logging.info(f'Train genotype: {genotype}')

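        # Build the dataset-specific network and data transforms; ImageNet
        # additionally gets a label-smoothed loss and folder-based datasets.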
        if my_dataset == MyDataset.CIFAR10:
            self.model = NetworkCIFAR(args.init_ch, 10, args.layers,
                                      args.auxiliary, genotype)
            train_transform, valid_transform = utils._data_transforms_cifar10(
                args)
            train_data = dset.CIFAR10(root=args.data,
                                      train=True,
                                      download=True,
                                      transform=train_transform)
            valid_data = dset.CIFAR10(root=args.data,
                                      train=False,
                                      download=True,
                                      transform=valid_transform)

        elif my_dataset == MyDataset.CIFAR100:
            self.model = NetworkCIFAR(args.init_ch, 100, args.layers,
                                      args.auxiliary, genotype)
            train_transform, valid_transform = utils._data_transforms_cifar100(
                args)
            train_data = dset.CIFAR100(root=args.data,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
            valid_data = dset.CIFAR100(root=args.data,
                                       train=False,
                                       download=True,
                                       transform=valid_transform)

        elif my_dataset == MyDataset.ImageNet:
            self.model = NetworkImageNet(args.init_ch, 1000, args.layers,
                                         args.auxiliary, genotype)
            self.__criterion_smooth = CrossEntropyLabelSmooth(
                1000, args.label_smooth).to(device)
            traindir = os.path.join(args.data, 'train')
            validdir = os.path.join(args.data, 'val')
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            train_data = dset.ImageFolder(
                traindir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ColorJitter(brightness=0.4,
                                           contrast=0.4,
                                           saturation=0.4,
                                           hue=0.2),
                    transforms.ToTensor(),
                    normalize,
                ]))
            valid_data = dset.ImageFolder(
                validdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]))
        else:
            raise ValueError(f'unsupported dataset: {my_dataset}')

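        # Optionally resume from a checkpoint; when training resumes, the
        # remaining epoch budget is reduced by the epochs already completed.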
        checkpoint = None
        if use_DataParallel:
            logging.info('using DataParallel')
            if args.checkpoint_path:
                checkpoint = torch.load(args.checkpoint_path)
                utils.load(self.model, checkpoint['state_dict'],
                           args.to_parallel)
                self.__previous_epochs = checkpoint['epoch']
                args.epochs -= self.__previous_epochs
                if args.epochs <= 0:
                    raise Exception('args.epochs is too small.')

            self.model = nn.DataParallel(self.model)
            self.__module = self.model.module
            torch.cuda.manual_seed_all(args.seed)
        else:
            if args.checkpoint_path:
                checkpoint = torch.load(args.checkpoint_path)
                utils.load(self.model, checkpoint['state_dict'],
                           args.to_parallel)
                args.epochs -= checkpoint['epoch']
                if args.epochs <= 0:
                    raise Exception('args.epochs is too small.')
            torch.cuda.manual_seed(args.seed)
            self.__module = self.model

        self.model.to(device)

        param_size = utils.count_parameters_in_MB(self.model)
        logging.info(f'param size = {param_size}MB')

        self.__criterion = nn.CrossEntropyLoss().to(device)

        self.__optimizer = torch.optim.SGD(self.__module.parameters(),
                                           args.lr,
                                           momentum=args.momentum,
                                           weight_decay=args.wd)
        if checkpoint:
            self.__optimizer.load_state_dict(checkpoint['optimizer'])

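        # Heuristic: four data-loading workers per visible GPU. With
        # choose_cell, the training set is split in two: the first portion
        # trains the weights, the held-out portion scores candidate cells.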
        num_workers = torch.cuda.device_count() * 4
        if choose_cell:
            num_train = len(train_data)  # 50000
            indices = list(range(num_train))
            split = int(np.floor(args.train_portion * num_train))  # 25000

            self.__train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batchsz,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(
                    indices[:split]),
                pin_memory=True,
                num_workers=num_workers)

            self.__valid_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batchsz,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(
                    indices[split:]),
                pin_memory=True,
                num_workers=num_workers)
        else:
            self.__train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batchsz,
                shuffle=True,
                pin_memory=True,
                num_workers=num_workers)

            self.__valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=args.batchsz,
                shuffle=False,
                pin_memory=True,
                num_workers=num_workers)

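        # Cosine annealing for the CIFAR runs, stepwise decay for ImageNet.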
        # The original `my_dataset == MyDataset.CIFAR10 or MyDataset.CIFAR100`
        # was always truthy (the right operand is a non-false enum member);
        # a membership test expresses the intended check.
        if my_dataset in (MyDataset.CIFAR10, MyDataset.CIFAR100):
            self.__scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.__optimizer, args.epochs)
        elif my_dataset == MyDataset.ImageNet:
            self.__scheduler = torch.optim.lr_scheduler.StepLR(
                self.__optimizer, args.decay_period, gamma=args.gamma)
        else:
            raise ValueError(f'unsupported dataset: {my_dataset}')

        if checkpoint:
            self.__scheduler.load_state_dict(checkpoint['scheduler'])

    def __train_epoch(self, train_queue, model, criterion, optimizer, epoch):
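        """Run one training epoch; returns (top-1 accuracy, average loss)."""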
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        model.train()

        with tqdm(train_queue) as progress:
            progress.set_description_str(f'Train epoch {epoch}')

            for step, (x, target) in enumerate(progress):

                x, target = x.to(device), target.to(device, non_blocking=True)

                optimizer.zero_grad()
                logits, logits_aux = model(x)
                loss = criterion(logits, target)
                if self.__args.auxiliary:
                    loss_aux = criterion(logits_aux, target)
                    loss += self.__args.auxiliary_weight * loss_aux
                loss.backward()
                nn.utils.clip_grad_norm_(
                    model.module.parameters() if use_DataParallel else
                    model.parameters(), self.__args.grad_clip)
                optimizer.step()

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = x.size(0)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                progress.set_postfix_str(f'loss: {objs.avg}, top1: {top1.avg}')

                if step % self.__args.report_freq == 0:
                    logging.info(
                        f'Step:{step:03} loss:{objs.avg} acc1:{top1.avg} acc5:{top5.avg}'
                    )

        return top1.avg, objs.avg

    def __infer_epoch(self, valid_queue, model, criterion, epoch):
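        """Evaluate one epoch; returns (top-1, top-5, average loss)."""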
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        model.eval()

        with tqdm(valid_queue) as progress:
            progress.set_description_str(f'Valid epoch {epoch}')
            for step, (x, target) in enumerate(progress):

                x = x.to(device)
                target = target.to(device, non_blocking=True)

                with torch.no_grad():
                    logits, _ = model(x)
                    loss = criterion(logits, target)

                    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                    n = x.size(0)
                    objs.update(loss.item(), n)
                    top1.update(prec1.item(), n)
                    top5.update(prec5.item(), n)

                progress.set_postfix_str(f'loss: {objs.avg}, top1: {top1.avg}')

                if step % self.__args.report_freq == 0:
                    logging.info(
                        f'>> Validation: {step:03} {objs.avg} {top1.avg} {top5.avg}'
                    )

        return top1.avg, top5.avg, objs.avg

    def train(self) -> Tuple[float, float, float]:
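        """Full training loop; returns (train_acc, valid_acc_top1, best_acc_top1)."""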

        best_acc_top1 = 0
        for epoch in tqdm(range(self.__args.epochs), desc='Total Progress'):
            logging.info(f'epoch {epoch} lr {self.__scheduler.get_last_lr()[0]}')
            self.__module.drop_path_prob = self.__args.drop_path_prob * epoch / self.__args.epochs

            # The label-smoothed loss is built only for ImageNet and was
            # otherwise unused; apply it there, plain cross-entropy elsewhere.
            train_criterion = (self.__criterion_smooth
                               if self.__dataset == MyDataset.ImageNet
                               else self.__criterion)
            train_acc, train_obj = self.__train_epoch(self.__train_queue,
                                                      self.model,
                                                      train_criterion,
                                                      self.__optimizer,
                                                      epoch + 1)
            # Step the scheduler after the epoch's optimizer updates
            # (the PyTorch >= 1.1 calling order).
            self.__scheduler.step()
            logging.info(f'train_acc: {train_acc}')

            valid_acc_top1, valid_acc_top5, valid_obj = self.__infer_epoch(
                self.__valid_queue, self.model, self.__criterion, epoch + 1)
            logging.info(f'valid_acc: {valid_acc_top1}')
            if self.__dataset == MyDataset.ImageNet:
                logging.info(f'valid_acc_top5 {valid_acc_top5}')

            is_best = False
            if valid_acc_top1 > best_acc_top1:
                best_acc_top1 = valid_acc_top1
                is_best = True

            utils.save(self.model, os.path.join(self.__args.save,
                                                'trained.pt'))
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1 + self.__previous_epochs,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.__optimizer.state_dict(),
                    'scheduler': self.__scheduler.state_dict()
                },
                is_best=is_best,
                save_path=self.__args.save)
            logging.info('saved to trained.pt and checkpoint.pth.tar')

        return train_acc, valid_acc_top1, best_acc_top1