Example #1
# The listing omits the imports; the standard-library and PyTorch ones below
# follow from usage, while `nets.nn` and `utils.util` are assumptions about
# the project modules that provide EfficientNet, RMSprop, EMA, StepLR and the
# training helpers.
import csv
import os

import torch
import tqdm
from torch.utils import data
from torchvision import datasets, transforms

from nets import nn
from utils import util


def main():
    epochs = 450
    device = torch.device('cuda')
    data_dir = '../Dataset/IMAGENET'
    num_gpu = torch.cuda.device_count()
    v_batch_size = 16 * num_gpu
    t_batch_size = 256 * num_gpu

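    # `num_class` and `version` are module-level globals in the original file
    # (not shown here); `version[2]` is used below as the input resolution.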
    model = nn.EfficientNet(num_class, version[0], version[1],
                            version[3]).to(device)
    optimizer = nn.RMSprop(util.add_weight_decay(model),
                           0.012 * num_gpu,
                           0.9,
                           1e-3,
                           momentum=0.9)

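    # Wrap for multi-GPU, then run one dummy forward pass at the target
    # resolution, presumably to build and warm up the model before training.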
    model = torch.nn.DataParallel(model)
    _ = model(torch.zeros(1, 3, version[2], version[2]).to(device))

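    # Track an exponential moving average of the weights; the EMA copy is what
    # gets validated and checkpointed below.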
    ema = nn.EMA(model)
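    # The project's CrossEntropyLoss (presumably label-smoothed) for training,
    # the stock torch.nn version for validation.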
    t_criterion = nn.CrossEntropyLoss().to(device)
    v_criterion = torch.nn.CrossEntropyLoss().to(device)

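    # Standard ImageNet channel statistics.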
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    t_dataset = datasets.ImageFolder(
        os.path.join(data_dir, 'train'),
        transforms.Compose([
            util.RandomResize(version[2]),
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ]))
    v_dataset = datasets.ImageFolder(
        os.path.join(data_dir, 'val'),
        transforms.Compose([
            transforms.Resize(version[2] + 32),
            transforms.CenterCrop(version[2]),
            transforms.ToTensor(), normalize
        ]))

    t_loader = data.DataLoader(t_dataset,
                               batch_size=t_batch_size,
                               shuffle=True,
                               num_workers=os.cpu_count(),
                               pin_memory=True)
    v_loader = data.DataLoader(v_dataset,
                               batch_size=v_batch_size,
                               shuffle=False,
                               num_workers=os.cpu_count(),
                               pin_memory=True)

    scheduler = nn.StepLR(optimizer)
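    # GradScaler scales the loss to avoid fp16 gradient underflow; it is
    # disabled automatically when CUDA is unavailable.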
    amp_scale = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    # csv files should be opened with newline='' so rows aren't double-spaced.
    with open(f'weights/{scheduler}.csv', 'w', newline='') as summary:
        writer = csv.DictWriter(
            summary,
            fieldnames=['epoch', 't_loss', 'v_loss', 'acc@1', 'acc@5'])
        writer.writeheader()
        best_acc1 = 0
        for epoch in range(epochs):
            print(('\n' + '%10s' * 2) % ('epoch', 'loss'))
            t_bar = tqdm.tqdm(t_loader, total=len(t_loader))
            model.train()
            t_loss = util.AverageMeter()
            v_loss = util.AverageMeter()
            for images, target in t_bar:
                loss, _, _, _ = batch_fn(images, target, model, device,
                                         t_criterion)
                optimizer.zero_grad()
                amp_scale.scale(loss).backward()
                amp_scale.step(optimizer)
                amp_scale.update()

                ema.update(model)
                torch.cuda.synchronize()
                t_loss.update(loss.item(), images.size(0))

                t_bar.set_description(('%10s' + '%10.4g') %
                                      ('%g/%g' % (epoch + 1, epochs), loss))
            top1 = util.AverageMeter()
            top5 = util.AverageMeter()

            ema_model = ema.model.eval()
            with torch.no_grad():
                for images, target in tqdm.tqdm(v_loader, ('%10s' * 2) %
                                                ('acc@1', 'acc@5')):
                    loss, acc1, acc5, output = batch_fn(
                        images, target, ema_model, device, v_criterion, False)
                    torch.cuda.synchronize()
                    v_loss.update(loss.item(), output.size(0))
                    top1.update(acc1.item(), images.size(0))
                    top5.update(acc5.item(), images.size(0))
                acc1, acc5 = top1.avg, top5.avg
                print('%10.3g' * 2 % (acc1, acc5))

            scheduler.step(epoch + 1)
            writer.writerow({
                'epoch': epoch + 1,
                't_loss': f'{t_loss.avg:.4f}',
                'v_loss': f'{v_loss.avg:.4f}',
                'acc@1': f'{acc1:.3f}',
                'acc@5': f'{acc5:.3f}'
            })
            util.save_checkpoint({'state_dict': ema.model.state_dict()},
                                 acc1 > best_acc1)
            best_acc1 = max(acc1, best_acc1)
    torch.cuda.empty_cache()
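
Neither example shows `batch_fn` or `util.AverageMeter`. The sketches below are
inferred only from the call sites above; the bodies are assumptions, not the
repository's actual code. (Example #2 reuses the same imports as Example #1.)

def batch_fn(images, target, model, device, criterion, training=True):
    # Hypothetical reconstruction. Move the batch to the GPU; non_blocking
    # pairs with pin_memory=True in the loaders above.
    images = images.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    # Forward under autocast so the GradScaler in the caller scales a
    # mixed-precision loss.
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        output = model(images)
        loss = criterion(output, target)
    if training:
        # The training loop only unpacks the loss.
        return loss, None, None, output
    # Top-1 / top-5 accuracy in percent for validation.
    _, pred = output.topk(5, dim=1)
    correct = pred.t().eq(target.view(1, -1))
    acc1 = correct[:1].flatten().float().sum() * (100.0 / images.size(0))
    acc5 = correct[:5].flatten().float().sum() * (100.0 / images.size(0))
    return loss, acc1, acc5, output


class AverageMeter:
    # Hypothetical running average behind the loss/accuracy columns.
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)
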
Example #2
# Assumes the same imports as Example #1. `data_dir` and the `batch` and
# `test` helpers are module-level definitions in the original file, not shown
# in this listing.
def train(args):
    epochs = 350
    batch_size = 288
    util.set_seeds(args.rank)
    model = nn.EfficientNet().cuda()
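    # Linear LR scaling: 0.256 per 4096-image global batch.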
    lr = batch_size * torch.cuda.device_count() * 0.256 / 4096
    optimizer = nn.RMSprop(util.add_weight_decay(model), lr, 0.9, 1e-3, momentum=0.9)
    ema = nn.EMA(model)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
    else:
        model = torch.nn.DataParallel(model)
    criterion = nn.CrossEntropyLoss().cuda()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                   transforms.Compose([util.RandomResize(),
                                                       transforms.ColorJitter(0.4, 0.4, 0.4),
                                                       transforms.RandomHorizontalFlip(),
                                                       util.RandomAugment(),
                                                       transforms.ToTensor(), normalize]))
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        sampler = None

    # Shuffle when no DistributedSampler is used; the sampler shuffles otherwise.
    loader = data.DataLoader(dataset, batch_size, sampler=sampler, shuffle=sampler is None,
                             num_workers=8, pin_memory=True)

    scheduler = nn.StepLR(optimizer)
    amp_scale = torch.cuda.amp.GradScaler()
    # Every rank opens (and truncates) this file, but only rank 0 writes to it.
    with open(f'weights/{scheduler}.csv', 'w', newline='') as f:
        if args.local_rank == 0:
            writer = csv.DictWriter(f, fieldnames=['epoch', 'acc@1', 'acc@5'])
            writer.writeheader()
        best_acc1 = 0
        for epoch in range(epochs):
            if args.distributed:
                sampler.set_epoch(epoch)
            if args.local_rank == 0:
                print(('\n' + '%10s' * 2) % ('epoch', 'loss'))
                bar = tqdm.tqdm(loader, total=len(loader))
            else:
                bar = loader
            model.train()
            for images, target in bar:
                loss = batch(images, target, model, criterion)
                optimizer.zero_grad()
                amp_scale.scale(loss).backward()
                amp_scale.step(optimizer)
                amp_scale.update()

                ema.update(model)
                torch.cuda.synchronize()
                if args.local_rank == 0:
                    bar.set_description(('%10s' + '%10.4g') % ('%g/%g' % (epoch + 1, epochs), loss))

            scheduler.step(epoch + 1)
            if args.local_rank == 0:
                acc1, acc5 = test(ema.model.eval())
                writer.writerow({'acc@1': f'{acc1:.3f}',
                                 'acc@5': f'{acc5:.3f}',
                                 'epoch': str(epoch + 1).zfill(3)})
                util.save_checkpoint({'state_dict': ema.model.state_dict()}, acc1 > best_acc1)
                best_acc1 = max(acc1, best_acc1)
    if args.distributed:
        torch.distributed.destroy_process_group()
    torch.cuda.empty_cache()
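
The learning rate in Example #2 follows a linear-scaling rule of 0.256 per
4096-image global batch (the EfficientNet training recipe), so with the
per-process batch of 288 on, say, 8 GPUs: 288 * 8 * 0.256 / 4096 = 0.144.

`nn.EMA` is also not shown. Below is a minimal weight-EMA sketch consistent
with how both examples use it; the decay value and the `.module` unwrapping
are assumptions:

import copy


class EMA:
    def __init__(self, model, decay=0.9999):  # decay value is an assumption
        # Keep a frozen, detached copy whose weights trail the live model.
        src = model.module if hasattr(model, 'module') else model
        self.model = copy.deepcopy(src).eval()
        self.decay = decay
        for p in self.model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # DataParallel / DistributedDataParallel hide the real model in .module.
        src = model.module if hasattr(model, 'module') else model
        for ema_t, t in zip(self.model.state_dict().values(),
                            src.state_dict().values()):
            if ema_t.dtype.is_floating_point:
                # ema = decay * ema + (1 - decay) * current
                ema_t.mul_(self.decay).add_(t, alpha=1 - self.decay)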