Example #1
def test(model=None):
    # Track whether we are running standalone (no model passed in); `model` is
    # reassigned below, so this flag is what the final branch checks.
    standalone = model is None
    if standalone:
        model = nn.EfficientNet()
        model.load_state_dict(torch.load('weights/best.pt', 'cpu')['state_dict'])
        model = model.cuda()
        model.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                   transforms.Compose([transforms.Resize(416),
                                                       transforms.CenterCrop(384),
                                                       transforms.ToTensor(), normalize]))

    loader = data.DataLoader(dataset, 48, num_workers=os.cpu_count(), pin_memory=True)
    top1 = util.AverageMeter()
    top5 = util.AverageMeter()
    with torch.no_grad():
        for images, target in tqdm.tqdm(loader, ('%10s' * 2) % ('acc@1', 'acc@5')):
            acc1, acc5 = batch(images, target, model)
            torch.cuda.synchronize()
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
        acc1, acc5 = top1.avg, top5.avg
        print('%10.3g' * 2 % (acc1, acc5))
    if standalone:
        torch.cuda.empty_cache()
    else:
        return acc1, acc5
Example #2
def evaluate():
    model = nn.EfficientNet()
    model.load_state_dict(torch.load(os.path.join('weights', 'best.pt'), map_location='cpu')['state_dict'])
    model = model.to(device)
    model.eval()
    v_criterion = torch.nn.CrossEntropyLoss().to(device)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    v_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                     transforms.Compose([transforms.Resize(416),
                                                         transforms.CenterCrop(384),
                                                         transforms.ToTensor(), normalize]))

    v_loader = data.DataLoader(v_dataset, batch_size=64, shuffle=False,
                               num_workers=4, pin_memory=True)
    top1 = util.AverageMeter()
    top5 = util.AverageMeter()
    with torch.no_grad():
        for images, target in tqdm.tqdm(v_loader, ('%10s' * 2) % ('acc@1', 'acc@5')):
            loss, acc1, acc5, output = batch_fn(images, target, model, v_criterion, False)
            torch.cuda.synchronize()
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
        acc1, acc5 = top1.avg, top5.avg
        print('%10.3g' * 2 % (acc1, acc5))

    torch.cuda.empty_cache()
Example #3
def evaluate(model, data_loader, global_stats, mode='train'):
    # Report accuracy for the classification task
    eval_time = util.Timer()
    start_acc = util.AverageMeter()

    # Make predictions
    examples = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred_s = model.predict(ex)
        answer = ex[5]
        # We get metrics for independent start/end and joint start/end
        start_acc.update(Evaluate.accuracies(pred_s, answer.cpu().data.numpy()), 1)

        # If getting train accuracies, sample max 10k
        examples += batch_size
        if mode == 'train' and examples >= 1e4:
            break

    logger.info('%s valid unofficial accuracy: Epoch = %d | acc = %.2f | ' %
                (mode, global_stats['epoch'], start_acc.avg) +
                'examples = %d | ' %
                (examples) +
                'valid time = %.2f (s)' % eval_time.time())

    return {'acc': start_acc.avg}
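Examples #3 and #6 also time themselves through a `util.Timer` object read via `.time()`. A minimal sketch under that assumption (elapsed wall-clock seconds since construction or `reset()`); the real helper may additionally support pausing and resuming.

import time

class Timer:
    """Minimal wall-clock timer (assumed interface: .time() returns elapsed seconds)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start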
Example #4
    def test(self):

        leader_accuracy = utils.AverageMeter()
        self.leader_model.eval()

        start_time = time.time()
        with torch.no_grad():
            for _, (inputs, labels) in enumerate(self.testLoader):

                inputs, labels = inputs.to(self.device), labels.to(self.device)
                leader_output, _ = self.leader_model(inputs)

                leader_prec = utils.accuracy(leader_output, labels.data, topk=(1, ))
                leader_accuracy.update(leader_prec[0], inputs.size(0))

            current_time = time.time()

            print('Model[{}]:\tAccuracy {:.2f}%\tTime {:.2f}s'
                  .format('Leader', float(leader_accuracy.avg), (current_time - start_time)))
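Several snippets (Examples #4, #10 and #11) call a `utils.accuracy(output, target, topk=...)` helper that returns one precision value per requested k. A hedged sketch following the usual top-k accuracy recipe; the exact scaling and return type of each project's own helper may differ.

import torch

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent for each k in `topk` (assumed behaviour)."""
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk highest-scoring classes per sample: (batch, maxk) -> (maxk, batch)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    results = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        results.append(correct_k.mul_(100.0 / batch_size))
    return results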
Example #5
def validate(model, opt):
    # --- start evaluation loop --- 
    logging.info('Starting evaluation loop ...')
    model.reset()
    assert(not model.net_D.training)
    val_dset = PairedDataset(opt, os.path.join(opt.real_im_path, 'val'),
                             os.path.join(opt.fake_im_path, 'val'),
                             is_val=True)
    val_dl = DataLoader(val_dset, batch_size=opt.batch_size // 2,
                        num_workers=opt.nThreads, pin_memory=False,
                        shuffle=False)
    val_losses = OrderedDict([(k + '_val', util.AverageMeter())
                              for k in model.loss_names])
    fake_label = opt.fake_class_id
    real_label = 1 - fake_label
    val_start_time = time.time()
    for i, ims in enumerate(val_dl):
        ims_real = ims['original'].to(opt.gpu_ids[0])
        ims_fake = ims['manipulated'].to(opt.gpu_ids[0])
        labels_real = real_label * torch.ones(ims_real.shape[0], dtype=torch.long).to(opt.gpu_ids[0])
        labels_fake = fake_label * torch.ones(ims_fake.shape[0], dtype=torch.long).to(opt.gpu_ids[0])

        inputs = dict(ims=torch.cat((ims_real, ims_fake), axis=0),
                      labels=torch.cat((labels_real, labels_fake), axis=0))

        # forward pass
        model.reset()
        model.set_input(inputs)
        model.test(True)
        losses = model.get_current_losses()

        # update val losses
        for k, v in losses.items():
            val_losses[k + '_val'].update(v, n=len(inputs['labels']))

    # get average val losses
    for k, v in val_losses.items():
        val_losses[k] = v.avg

    return val_losses
Example #6
def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = util.AverageMeter()
    epoch_time = util.Timer()
    # Run one epoch
    for idx, ex in enumerate(data_loader):
        train_loss.update(*model.update(ex))  # run on one batch

        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()
    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
Example #7
def main():
    epochs = 450
    device = torch.device('cuda')
    data_dir = '../Dataset/IMAGENET'
    num_gpu = torch.cuda.device_count()
    v_batch_size = 16 * num_gpu
    t_batch_size = 256 * num_gpu

    model = nn.EfficientNet(num_class, version[0], version[1],
                            version[3]).to(device)
    optimizer = nn.RMSprop(util.add_weight_decay(model),
                           0.012 * num_gpu,
                           0.9,
                           1e-3,
                           momentum=0.9)

    model = torch.nn.DataParallel(model)
    _ = model(torch.zeros(1, 3, version[2], version[2]).to(device))

    ema = nn.EMA(model)
    t_criterion = nn.CrossEntropyLoss().to(device)
    v_criterion = torch.nn.CrossEntropyLoss().to(device)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    t_dataset = datasets.ImageFolder(
        os.path.join(data_dir, 'train'),
        transforms.Compose([
            util.RandomResize(version[2]),
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ]))
    v_dataset = datasets.ImageFolder(
        os.path.join(data_dir, 'val'),
        transforms.Compose([
            transforms.Resize(version[2] + 32),
            transforms.CenterCrop(version[2]),
            transforms.ToTensor(), normalize
        ]))

    t_loader = data.DataLoader(t_dataset,
                               batch_size=t_batch_size,
                               shuffle=True,
                               num_workers=os.cpu_count(),
                               pin_memory=True)
    v_loader = data.DataLoader(v_dataset,
                               batch_size=v_batch_size,
                               shuffle=False,
                               num_workers=os.cpu_count(),
                               pin_memory=True)

    scheduler = nn.StepLR(optimizer)
    amp_scale = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    with open(f'weights/{scheduler}.csv', 'w') as summary:
        writer = csv.DictWriter(
            summary,
            fieldnames=['epoch', 't_loss', 'v_loss', 'acc@1', 'acc@5'])
        writer.writeheader()
        best_acc1 = 0
        for epoch in range(0, epochs):
            print(('\n' + '%10s' * 2) % ('epoch', 'loss'))
            t_bar = tqdm.tqdm(t_loader, total=len(t_loader))
            model.train()
            t_loss = util.AverageMeter()
            v_loss = util.AverageMeter()
            for images, target in t_bar:
                loss, _, _, _ = batch_fn(images, target, model, device,
                                         t_criterion)
                optimizer.zero_grad()
                amp_scale.scale(loss).backward()
                amp_scale.step(optimizer)
                amp_scale.update()

                ema.update(model)
                torch.cuda.synchronize()
                t_loss.update(loss.item(), images.size(0))

                t_bar.set_description(('%10s' + '%10.4g') %
                                      ('%g/%g' % (epoch + 1, epochs), loss))
            top1 = util.AverageMeter()
            top5 = util.AverageMeter()

            ema_model = ema.model.eval()
            with torch.no_grad():
                for images, target in tqdm.tqdm(v_loader, ('%10s' * 2) %
                                                ('acc@1', 'acc@5')):
                    loss, acc1, acc5, output = batch_fn(
                        images, target, ema_model, device, v_criterion, False)
                    torch.cuda.synchronize()
                    v_loss.update(loss.item(), output.size(0))
                    top1.update(acc1.item(), images.size(0))
                    top5.update(acc5.item(), images.size(0))
                acc1, acc5 = top1.avg, top5.avg
                print('%10.3g' * 2 % (acc1, acc5))

            scheduler.step(epoch + 1)
            writer.writerow({
                'epoch': epoch + 1,
                't_loss': f'{t_loss.avg:.4f}',
                'v_loss': f'{v_loss.avg:.4f}',
                'acc@1': f'{acc1:.3f}',
                'acc@5': f'{acc5:.3f}'
            })
            util.save_checkpoint({'state_dict': ema.model.state_dict()},
                                 acc1 > best_acc1)
            best_acc1 = max(acc1, best_acc1)
    torch.cuda.empty_cache()
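Examples #2 and #7 delegate the per-batch forward pass to a `batch_fn` helper that is not part of the listing, and the two call sites even pass slightly different argument lists. Purely to make the call sites readable, here is a hypothetical sketch matching Example #7's `(images, target, model, device, criterion, training)` form and its `(loss, acc1, acc5, output)` return; it is not the project's actual implementation.

def batch_fn(images, target, model, device, criterion, training=True):
    # Hypothetical helper: move the batch to the device, run the model (under
    # autocast when training with the GradScaler above), and compute loss plus
    # top-1/top-5 accuracy.
    images = images.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    with torch.cuda.amp.autocast(enabled=training and torch.cuda.is_available()):
        output = model(images)
        loss = criterion(output, target)
    acc1, acc5 = accuracy(output, target, topk=(1, 5))  # top-k helper as sketched earlier
    return loss, acc1, acc5, output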
Example #8
def train(opt):
    print("Random Seed: ", opt.seed)
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    cudnn.benchmark = True
    device = 'cuda'
    batch_size = int(opt.batch_size)

    # tensorboard
    os.makedirs(os.path.join(opt.outf, 'runs'), exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(opt.outf, 'runs'))

    # classifier follows architecture and initialization from attribute
    # classifiers in stylegan:
    # https://github.com/NVlabs/stylegan/blob/master/metrics/linear_separability.py#L136
    # https://github.com/NVlabs/stylegan/blob/master/training/networks_stylegan.py#L564
    net = attribute_classifier.D(3,
                                 resolution=256,
                                 fixed_size=True,
                                 use_mbstd=False)
    # use random normal bc wscale will rescale weights, see tf source
    # https://github.com/NVlabs/stylegan/blob/master/training/networks_stylegan.py#L148
    netinit.init_weights(net, init_type='normal', gain=1.)
    net = net.to(device)

    # losses + optimizers
    bce_loss = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(net.parameters(),
                           lr=opt.lr,
                           betas=(opt.beta1, 0.999))

    # datasets -- added random horizontal flipping for training
    train_transform = data.get_transform('celebahq', 'imtrain')
    train_dset = data.get_dataset('celebahq',
                                  'train',
                                  opt.attribute,
                                  load_w=False,
                                  transform=train_transform)
    print("Training transform:")
    print(train_transform)
    val_transform = data.get_transform('celebahq', 'imval')
    val_dset = data.get_dataset('celebahq',
                                'val',
                                opt.attribute,
                                load_w=False,
                                transform=val_transform)
    print("Validation transform:")
    print(val_transform)
    train_loader = DataLoader(train_dset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              pin_memory=False,
                              num_workers=opt.workers)
    val_loader = DataLoader(val_dset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            pin_memory=False,
                            num_workers=opt.workers)
    start_ep = 0
    best_val_acc = 0.0
    best_val_epoch = 0

    for epoch in pbar(range(start_ep, opt.niter + 1)):
        # average meter for train/val loss and train/val acc
        metrics = dict(train_loss=util.AverageMeter(),
                       val_loss=util.AverageMeter(),
                       train_acc=util.AverageMeter(),
                       val_acc=util.AverageMeter())

        # train loop
        for step, (im, label) in enumerate(pbar(train_loader)):
            im = im.cuda()
            label = label.cuda().float()

            net.zero_grad()
            logit, softmaxed = attribute_utils.get_softmaxed(net, im)
            # enforces that negative logit --> our label = 1
            loss = bce_loss(logit, 1 - label)
            predicted = (softmaxed > 0.5).long()
            correct = (predicted == label).float().mean().item()
            metrics['train_loss'].update(loss, n=len(label))
            metrics['train_acc'].update(correct, n=len(label))
            loss.backward()
            optimizer.step()
            if step % 200 == 0:
                pbar.print("%s: %0.2f" %
                           ('train loss', metrics['train_loss'].avg))
                pbar.print("%s: %0.2f" %
                           ('train acc', metrics['train_acc'].avg))

        # val loop
        net = net.eval()
        with torch.no_grad():
            for step, (im, label) in enumerate(pbar(val_loader)):
                im = im.cuda()
                label = label.cuda().float()
                logit, softmaxed = attribute_utils.get_softmaxed(net, im)
                predicted = (softmaxed > 0.5).long()
                correct = (predicted == label).float().mean().item()
                loss = bce_loss(logit, 1 - label)
                metrics['val_loss'].update(loss, n=len(label))
                metrics['val_acc'].update(correct, n=len(label))
        net = net.train()

        # send losses to tensorboard
        for k, v in metrics.items():
            pbar.print("Metrics at end of epoch")
            pbar.print("%s: %0.4f" % (k, v.avg))
            writer.add_scalar(k.replace('_', '/'), v.avg, epoch)

        # do checkpoint as latest
        util.make_checkpoint(net, optimizer, epoch, metrics['val_acc'].avg,
                             opt.outf, 'latest')

        if metrics['val_acc'].avg > best_val_acc:
            pbar.print("Updating best checkpoint at epoch %d" % epoch)
            pbar.print("Old Best Epoch %d Best Val %0.2f" %
                       (best_val_epoch, best_val_acc))
            # do checkpoint as best
            util.make_checkpoint(net, optimizer, epoch, metrics['val_acc'].avg,
                                 opt.outf, 'best')
            best_val_acc = metrics['val_acc'].avg
            best_val_epoch = epoch
            pbar.print("New Best Epoch %d Best Val %0.2f" %
                       (best_val_epoch, best_val_acc))
            with open("%s/best_val.txt" % opt.outf, "w") as f:
                f.write("Best Epoch %d Best Val %0.2f\n" %
                        (best_val_epoch, best_val_acc))

        if epoch >= best_val_epoch + 5:
            pbar.print("Exiting training")
            pbar.print("Best Val epoch %d" % best_val_epoch)
            pbar.print("Curr epoch %d" % epoch)
            break
Example #9
def train(opt):
    print("Random Seed: ", opt.seed)
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    cudnn.benchmark = True
    device = 'cuda'
    batch_size = int(opt.batch_size)
    domain = opt.domain

    # tensorboard
    os.makedirs(os.path.join(opt.outf, 'runs'), exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(opt.outf, 'runs'))

    #  datasets
    train_transform = data.get_transform(domain, 'imtrain')
    train_dset = data.get_dataset(domain,
                                  'train',
                                  load_w=True,
                                  transform=train_transform)
    print("Training transform:")
    print(train_transform)
    val_transform = data.get_transform(domain, 'imval')
    val_dset = data.get_dataset(domain,
                                'val',
                                load_w=True,
                                transform=val_transform)
    print("Validation transform:")
    print(val_transform)
    train_loader = DataLoader(train_dset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              pin_memory=False,
                              num_workers=opt.workers)
    val_loader = DataLoader(val_dset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            pin_memory=False,
                            num_workers=opt.workers)

    # classifier: resnet18 model
    net = torchvision.models.resnet18(
        num_classes=len(train_dset.coarse_labels))
    # load the model weights to finetune from
    ckpt_path = ('results/classifiers/%s/%s/net_best.pth' %
                 (opt.domain, opt.finetune_from))
    print("Finetuning model from %s" % ckpt_path)
    ckpt = torch.load(ckpt_path)
    state_dict = ckpt['state_dict']
    net.load_state_dict(state_dict)
    net = net.to(device)

    # losses + optimizers
    criterion = nn.CrossEntropyLoss().to(device)  # loss(output, target)
    optimizer = optim.Adam(net.parameters(),
                           lr=opt.lr,
                           betas=(opt.beta1, 0.999))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='max',
                                                     patience=10,
                                                     min_lr=1e-6,
                                                     verbose=True)

    start_ep = 0
    best_val_acc = 0.0
    best_val_epoch = 0

    # val tensor transform does not flip, train tensor transform
    # mimics random resized crop and random flip
    val_tensor_transform = data.get_transform(domain, 'tensorbase')
    train_tensor_transform = data.get_transform(domain, 'tensormixedtrain')

    # load GAN
    generator = domain_generator.define_generator('stylegan2', domain)

    for epoch in pbar(range(start_ep, opt.niter + 1)):
        # average meter for train/val loss and train/val acc
        metrics = dict(train_loss=util.AverageMeter(),
                       val_loss=util.AverageMeter(),
                       train_acc=util.AverageMeter(),
                       val_acc=util.AverageMeter())

        # train loop
        for step, item in enumerate(pbar(train_loader)):
            im_orig = item[0].cuda()
            opt_w = item[1].cuda()
            label = item[2].cuda()
            with torch.no_grad():
                if opt.perturb_type == 'stylemix':
                    seed = epoch * len(train_loader) + step
                    mix_latent = generator.seed2w(n=opt_w.shape[0], seed=seed)
                    generated_im = generator.perturb_stylemix(
                        opt_w,
                        opt.perturb_layer,
                        mix_latent,
                        n=opt_w.shape[0],
                        is_eval=False)
                elif opt.perturb_type == 'isotropic':
                    # used minimum value for isotropic setting
                    eps = np.min(
                        generator.perturb_settings['isotropic_eps_%s' %
                                                   opt.perturb_layer])
                    generated_im = generator.perturb_isotropic(
                        opt_w,
                        opt.perturb_layer,
                        eps=eps,
                        n=opt_w.shape[0],
                        is_eval=False)
                elif opt.perturb_type == 'pca':
                    # used median value for pca setting
                    eps = np.median(generator.perturb_settings['pca_eps'])
                    generated_im = generator.perturb_pca(opt_w,
                                                         opt.perturb_layer,
                                                         eps=eps,
                                                         n=opt_w.shape[0],
                                                         is_eval=False)
                else:
                    generated_im = generator.decode(opt_w)

                generated_im = train_tensor_transform(generated_im)
                # sanity check that the shapes match
                assert (generated_im.shape == im_orig.shape)

            # with 50% chance, take the original image for
            # training the classifier
            im = im_orig if torch.rand(1) > 0.5 else generated_im

            net.zero_grad()
            output = net(im)
            loss = criterion(output, label)
            accs, _ = accuracy(output, label, topk=(1, ))
            metrics['train_loss'].update(loss, n=len(label))
            metrics['train_acc'].update(accs[0], n=len(label))
            loss.backward()
            optimizer.step()
            if step % 200 == 0:
                pbar.print("%s: %0.6f" %
                           ('train loss', metrics['train_loss'].avg))
                pbar.print("%s: %0.6f" %
                           ('train acc', metrics['train_acc'].avg))

        # val loop
        net = net.eval()
        with torch.no_grad():
            for step, item in enumerate(pbar(val_loader)):
                im_orig = item[0].cuda()
                opt_w = item[1].cuda()
                label = item[2].cuda()
                with torch.no_grad():
                    # evaluate on the generated image
                    im = generator.decode(opt_w)
                    im = val_tensor_transform(im)
                assert (im.shape == im_orig.shape)
                output = net(im)
                loss = criterion(output, label)
                accs, _ = accuracy(output, label, topk=(1, ))
                metrics['val_loss'].update(loss, n=len(label))
                metrics['val_acc'].update(accs[0], n=len(label))
        net = net.train()

        scheduler.step(metrics['val_acc'].avg)

        # send losses to tensorboard
        for k, v in metrics.items():
            pbar.print("Metrics at end of epoch")
            pbar.print("%s: %0.4f" % (k, v.avg))
            writer.add_scalar(k.replace('_', '/'), v.avg, epoch)
        pbar.print("Learning rate: %0.6f" % optimizer.param_groups[0]['lr'])

        # do checkpoint as latest
        util.make_checkpoint(net, optimizer, epoch, metrics['val_acc'].avg,
                             opt.outf, 'latest')

        if metrics['val_acc'].avg > best_val_acc:
            pbar.print("Updating best checkpoint at epoch %d" % epoch)
            pbar.print("Old Best Epoch %d Best Val %0.6f" %
                       (best_val_epoch, best_val_acc))
            # do checkpoint as best
            util.make_checkpoint(net, optimizer, epoch, metrics['val_acc'].avg,
                                 opt.outf, 'best')
            best_val_acc = metrics['val_acc'].avg
            best_val_epoch = epoch
            pbar.print("New Best Epoch %d Best Val %0.6f" %
                       (best_val_epoch, best_val_acc))
            with open("%s/best_val.txt" % opt.outf, "w") as f:
                f.write("Best Epoch %d Best Val %0.6f\n" %
                        (best_val_epoch, best_val_acc))

        if (float(optimizer.param_groups[0]['lr']) <= 1e-6
                and epoch >= best_val_epoch + 50):
            pbar.print("Exiting training")
            pbar.print("Best Val epoch %d" % best_val_epoch)
            pbar.print("Curr epoch %d" % epoch)
            break
Example #10
    def test(self, epoch, topk=(1, )):

        losses = []
        accuracy = []
        top5_accuracy = []
        fusion_accuracy = utils.AverageMeter()
        leader_accuracy = utils.AverageMeter()
        average_accuracy = utils.AverageMeter()
        ensemble_accuracy = utils.AverageMeter()
        self.fusion_module.eval()
        self.leader_model.eval()
        for i in range(self.model_num):
            self.models[i].eval()
            accuracy.append(utils.AverageMeter())
            top5_accuracy.append(utils.AverageMeter())
        accuracy.append(fusion_accuracy)
        accuracy.append(leader_accuracy)

        start_time = time.time()
        with torch.no_grad():
            for batch_idx, (inputs, labels) in enumerate(self.testLoader):
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                outputs = []
                fusion_module_inputs = []
                leader_output, _ = self.leader_model(inputs)
                for i in range(self.model_num):
                    outputs.append(self.models[i](inputs))
                    fusion_module_inputs.append(
                        self.cat_feature_maps(
                            self.models[i].module.total_feature_maps,
                            self.models[i].module.extract_layers)[-1].detach())
                fusion_module_inputs = torch.cat(fusion_module_inputs, dim=1)
                fusion_output = self.fusion_module(fusion_module_inputs)

                # measure accuracy and record loss
                for i in range(self.model_num):
                    prec = utils.accuracy(outputs[i].data,
                                          labels.data,
                                          topk=topk)
                    accuracy[i].update(prec[0], inputs.size(0))
                    if len(topk) == 2:
                        top5_accuracy[i].update(prec[1], inputs.size(0))

                fusion_prec = utils.accuracy(fusion_output,
                                             labels.data,
                                             topk=topk)
                fusion_accuracy.update(fusion_prec[0], inputs.size(0))

                leader_prec = utils.accuracy(leader_output,
                                             labels.data,
                                             topk=topk)
                leader_accuracy.update(leader_prec[0], inputs.size(0))

                average_prec = utils.average_accuracy(outputs,
                                                      labels.data,
                                                      topk=topk)
                ensemble_prec = utils.ensemble_accuracy(outputs,
                                                        labels.data,
                                                        topk=topk)

                average_accuracy.update(average_prec[0], inputs.size(0))
                ensemble_accuracy.update(ensemble_prec[0], inputs.size(0))

            current_time = time.time()

            msg = 'Epoch[{}]\tTime {:.2f}s\t'.format(epoch,
                                                     current_time - start_time)

            for i in range(self.model_num):
                msg += 'Model[{}]:\tAccuracy {:.2f}%\t'.format(
                    i, float(accuracy[i].avg))
            msg += 'Model[{}]:\tAccuracy {:.2f}%\t'.format(
                'Fusion', float(fusion_accuracy.avg))
            msg += 'Model[{}]:\tAccuracy {:.2f}%\t'.format(
                'Leader', float(leader_accuracy.avg))

            msg += 'Average Acc:{:.2f}\tEnsemble Acc:{:.2f}'.format(
                float(average_accuracy.avg), float(ensemble_accuracy.avg))

            self.logger.info(msg + '\n')

        return losses, accuracy, top5_accuracy, average_accuracy, ensemble_accuracy
Example #11
    def train_with_test(self, epoch, topk=(1, )):

        accuracy = []
        losses = []
        ce_losses = []
        dml_losses = []
        diversity_losses = []
        self_distillation_feature_losses = []
        self_distillation_attention_losses = []
        self_distillation_losses = []

        fusion_accuracy = utils.AverageMeter()
        fusion_ce_loss = utils.AverageMeter()
        fusion_ensemble_loss = utils.AverageMeter()
        fusion_loss = utils.AverageMeter()

        leader_accuracy = utils.AverageMeter()
        leader_ce_loss = utils.AverageMeter()
        leader_ensemble_loss = utils.AverageMeter()
        leader_self_distillation_feature_loss = utils.AverageMeter()
        leader_self_distillation_attention_loss = utils.AverageMeter()
        leader_self_distillation_loss = utils.AverageMeter()
        leader_fusion_loss = utils.AverageMeter()
        leader_trans_fusion_loss = utils.AverageMeter()
        leader_loss = utils.AverageMeter()

        average_accuracy = utils.AverageMeter()
        ensemble_accuracy = utils.AverageMeter()

        self.fusion_module.train()
        self.leader_model.train()
        for i in range(self.model_num):
            self.models[i].train()
            losses.append(utils.AverageMeter())
            ce_losses.append(utils.AverageMeter())
            dml_losses.append(utils.AverageMeter())
            diversity_losses.append(utils.AverageMeter())
            self_distillation_feature_losses.append(utils.AverageMeter())
            self_distillation_attention_losses.append(utils.AverageMeter())
            self_distillation_losses.append(utils.AverageMeter())
            accuracy.append(utils.AverageMeter())

        dataset_size = len(self.trainLoader.dataset)
        print_freq = dataset_size // self.opt.train_batch_size // 10
        start_time = time.time()
        epoch_iter = 0

        for batch, (inputs, labels) in enumerate(self.trainLoader):

            inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.adjust_learning_rates(epoch, batch,
                                       dataset_size // self.train_batch_size)

            epoch_iter += self.train_batch_size

            ensemble_output = 0.0
            outputs = []
            total_feature_maps = []
            fusion_module_inputs = []
            leader_output, leader_trans_fusion_output = self.leader_model(
                inputs)
            for i in range(self.model_num):
                outputs.append(self.models[i](inputs))
                ensemble_output += outputs[-1]

                total_feature_maps.append(
                    self.cat_feature_maps(
                        self.models[i].module.total_feature_maps,
                        self.models[i].module.extract_layers))
                fusion_module_inputs.append(
                    total_feature_maps[-1][-1].detach())
            fusion_module_inputs = torch.cat(fusion_module_inputs, dim=1)
            fusion_output = self.fusion_module(fusion_module_inputs)

            ensemble_output = ensemble_output / self.model_num

            # backward models
            for i in range(self.model_num):

                loss_ce = self.criterion_CE(outputs[i], labels)
                loss_dml = 0.0

                for j in range(self.model_num):
                    if i != j:
                        loss_dml += self.criterion_KL(
                            F.log_softmax(outputs[i] / self.temperature,
                                          dim=1),
                            F.softmax(outputs[j].detach() / self.temperature,
                                      dim=1))

                if i != 0 and self.lambda_diversity > 0:
                    current_attention_map = total_feature_maps[i][-1].pow(
                        2).mean(1, keepdim=True)
                    other_attention_map = total_feature_maps[
                        i - 1][-1].detach().pow(2).mean(1, keepdim=True)
                    loss_diversity = self.lambda_diversity * self.diversity_loss(
                        current_attention_map, other_attention_map)
                    loss_self_distillation = self.lambda_diversity * \
                                             self.self_distillation_loss(self.sd_models[i - 1],
                                                                         total_feature_maps[i],
                                                                         input_feature_map=self.diversity_target(
                                                                             total_feature_maps[i - 1][-1].detach()))

                else:
                    loss_diversity = 0.0
                    loss_self_distillation = 0.0
                loss_dml = (self.temperature**
                            2) * loss_dml / (self.model_num - 1)
                loss = loss_ce + loss_dml + loss_diversity + loss_self_distillation

                # measure accuracy and record loss
                prec = utils.accuracy(outputs[i].data, labels.data, topk=topk)
                losses[i].update(loss.item(), inputs.size(0))
                ce_losses[i].update(loss_ce.item(), inputs.size(0))
                dml_losses[i].update(loss_dml, inputs.size(0))
                diversity_losses[i].update(loss_diversity, inputs.size(0))
                self_distillation_losses[i].update(loss_self_distillation,
                                                   inputs.size(0))
                accuracy[i].update(prec[0], inputs.size(0))

                self.optimizers[i].zero_grad()
                loss.backward()
                self.optimizers[i].step()

            # backward fusion module
            loss_fusion_ce = self.criterion_CE(fusion_output, labels)
            loss_fusion_ensemble = (self.temperature**2) * self.criterion_KL(
                F.log_softmax(fusion_output / self.temperature, dim=1),
                F.softmax(ensemble_output.detach() / self.temperature, dim=1))
            loss_fusion = loss_fusion_ce + loss_fusion_ensemble
            self.fusion_optimizer.zero_grad()
            loss_fusion.backward()
            self.fusion_optimizer.step()

            fusion_ce_loss.update(loss_fusion_ce.item(), inputs.size(0))
            fusion_ensemble_loss.update(loss_fusion_ensemble.item(),
                                        inputs.size(0))
            fusion_loss.update(loss_fusion.item(), inputs.size(0))
            fusion_prec = utils.accuracy(fusion_output, labels.data, topk=topk)
            fusion_accuracy.update(fusion_prec[0], inputs.size(0))

            # backward leader models
            leader_feature_maps = self.cat_feature_maps(
                self.leader_model.module.total_feature_maps,
                self.leader_model.module.extract_layers)
            fusion_feature_maps = self.cat_feature_maps(
                self.fusion_module.module.total_feature_maps,
                self.fusion_module.module.extract_layers)
            loss_leader_ce = self.criterion_CE(leader_output, labels)
            loss_leader_ensemble = (self.temperature**2) * self.criterion_KL(
                F.log_softmax(leader_output / self.temperature, dim=1),
                F.softmax(fusion_output.detach() / self.temperature, dim=1))
            loss_leader_fusion = self.lambda_fusion * self.fusion_loss(
                leader_feature_maps[-1].pow(2).mean(1, keepdim=True),
                fusion_feature_maps[-1].detach().pow(2).mean(1, keepdim=True))
            loss_leader_trans_fusion = self.lambda_fusion * \
                                       self.fusion_loss(leader_trans_fusion_output.pow(2).mean(1, keepdim=True),
                                                           fusion_module_inputs.pow(2).mean(1, keepdim=True))

            loss_leader_self_distillation = self.lambda_fusion * \
                                            self.self_distillation_loss(self.sd_leader_model, leader_feature_maps,
                                                                        input_feature_map=fusion_feature_maps[-1].detach())
            loss_leader = loss_leader_ce + loss_leader_ensemble + loss_leader_fusion + loss_leader_trans_fusion + loss_leader_self_distillation

            self.leader_optimizer.zero_grad()
            loss_leader.backward()
            self.leader_optimizer.step()

            leader_ce_loss.update(loss_leader_ce.item(), inputs.size(0))
            leader_ensemble_loss.update(loss_leader_ensemble.item(),
                                        inputs.size(0))
            leader_fusion_loss.update(loss_leader_fusion, inputs.size(0))
            leader_trans_fusion_loss.update(loss_leader_trans_fusion,
                                            inputs.size(0))
            leader_self_distillation_loss.update(loss_leader_self_distillation,
                                                 inputs.size(0))
            leader_loss.update(loss_leader.item(), inputs.size(0))
            leader_prec = utils.accuracy(leader_output, labels.data, topk=topk)
            leader_accuracy.update(leader_prec[0], inputs.size(0))

            # update self distillation model after all models updated
            for i in range(1, self.model_num):
                loss_self_distillation_feature, loss_self_distillation_attention = \
                    self.train_self_distillation_model(self.sd_models[i - 1],
                                                       self.sd_optimizers[i - 1],
                                                       target_feature_maps=total_feature_maps[i])
                self_distillation_feature_losses[i].update(
                    loss_self_distillation_feature, inputs.size(0))
                self_distillation_attention_losses[i].update(
                    loss_self_distillation_attention, inputs.size(0))

            loss_leader_self_distillation_feature, loss_leader_self_distillation_attention = \
                self.train_self_distillation_model(self.sd_leader_model,
                                                   self.sd_leader_optimizer,
                                                   target_feature_maps=leader_feature_maps)
            leader_self_distillation_feature_loss.update(
                loss_leader_self_distillation_feature, inputs.size(0))
            leader_self_distillation_attention_loss.update(
                loss_leader_self_distillation_attention, inputs.size(0))

            average_prec = utils.average_accuracy(outputs,
                                                  labels.data,
                                                  topk=topk)
            ensemble_prec = utils.ensemble_accuracy(outputs,
                                                    labels.data,
                                                    topk=topk)

            average_accuracy.update(average_prec[0], inputs.size(0))
            ensemble_accuracy.update(ensemble_prec[0], inputs.size(0))

            if batch % print_freq == 0 and batch != 0:
                current_time = time.time()
                cost_time = current_time - start_time

                msg = 'Epoch[{}] ({}/{})\tTime {:.2f}s\t'.format(
                    epoch, batch * self.train_batch_size, dataset_size,
                    cost_time)
                for i in range(self.model_num):

                    msg += '|Model[{}]: Loss:{:.4f}\t' \
                           'CE Loss:{:.4f}\tDML Loss:{:.4f}\t' \
                           'Diversity Loss:{:.4f}\tSD Feature:{:.4f}' \
                           'SD Attention:{:.4f}\tSelf Distillation Loss:{:.4f}\t' \
                           'Accuracy {:.2f}%\t'.format(
                        i, float(losses[i].avg), float(ce_losses[i].avg), float(dml_losses[i].avg),
                        float(diversity_losses[i].avg), float(self_distillation_feature_losses[i].avg),
                        float(self_distillation_attention_losses[i].avg), float(self_distillation_losses[i].avg),
                        float(accuracy[i].avg))
                msg += '|Model[{}]: Loss:{:.4f}\t' \
                       'CE Loss:{:.4f}\tKL Loss:{:.4f}\t' \
                       'Accuracy {:.2f}%\t'.format(
                    'fusion', float(fusion_loss.avg), float(fusion_ce_loss.avg), float(fusion_ensemble_loss.avg),
                    float(fusion_accuracy.avg))
                msg += '|Model[{}]: Loss:{:.4f}\t' \
                       'CE Loss:{:.4f}\tEnsemble Loss:{:.4f}\t' \
                       'Fusion Loss:{:.4f}\tTrans Fusion Loss:{:.4f}\t' \
                       'SD Feature:{:.4f}\tSD Attention:{:.4f}\t' \
                       'Self Distillation Loss:{:.4f}\tAccuracy {:.2f}%\t'.format(
                    'leader', float(leader_loss.avg), float(leader_ce_loss.avg),
                    float(leader_ensemble_loss.avg), float(leader_fusion_loss.avg), float(leader_trans_fusion_loss.avg),
                    float(leader_self_distillation_feature_loss.avg),
                    float(leader_self_distillation_attention_loss.avg),
                    float(leader_self_distillation_loss.avg), float(leader_accuracy.avg))

                msg += '|Average Acc:{:.2f}\tEnsemble Acc:{:.2f}'.format(
                    float(average_accuracy.avg), float(ensemble_accuracy.avg))
                self.logger.info(msg)

                start_time = current_time
Example #12
def train(opt):
    print("Random Seed: ", opt.seed)
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    cudnn.benchmark = True
    device = 'cuda'
    batch_size = int(opt.batch_size)
    domain = opt.domain

    # tensorboard
    os.makedirs(os.path.join(opt.outf, 'runs'), exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(opt.outf, 'runs'))

    #  datasets
    train_transform = data.get_transform(domain, 'imtrain')
    train_dset = data.get_dataset(domain,
                                  'train',
                                  load_w=False,
                                  transform=train_transform)
    print("Training transform:")
    print(train_transform)
    val_transform = data.get_transform(domain, 'imval')
    val_dset = data.get_dataset(domain,
                                'val',
                                load_w=False,
                                transform=val_transform)
    print("Validation transform:")
    print(val_transform)
    train_loader = DataLoader(train_dset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              pin_memory=False,
                              num_workers=opt.workers)
    val_loader = DataLoader(val_dset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            pin_memory=False,
                            num_workers=opt.workers)

    # classifier: resnet18 model
    net = torchvision.models.resnet18(
        num_classes=len(train_dset.coarse_labels))
    if not opt.train_from_scratch:
        state_dict = torchvision.models.utils.load_state_dict_from_url(
            torchvision.models.resnet.model_urls['resnet18'])
        del state_dict['fc.weight']
        del state_dict['fc.bias']
        net.load_state_dict(state_dict, strict=False)
    net = net.to(device)

    # losses + optimizers + scheduler
    # use smaller learning rate for the feature layers if initialized
    # with imagenet pretrained weights
    criterion = nn.CrossEntropyLoss().to(device)  # loss(output, target)
    fc_params = [k[1] for k in net.named_parameters() if k[0].startswith('fc')]
    feat_params = [
        k[1] for k in net.named_parameters() if not k[0].startswith('fc')
    ]
    feature_backbone_lr = opt.lr if opt.train_from_scratch else 0.1 * opt.lr
    print("Initial learning rate for feature backbone: %0.4f" %
          feature_backbone_lr)
    print("Initial learning rate for FC layer: %0.4f" % opt.lr)
    optimizer = optim.Adam([{
        'params': fc_params
    }, {
        'params': feat_params,
        'lr': feature_backbone_lr
    }],
                           lr=opt.lr,
                           betas=(opt.beta1, 0.999))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='max',
                                                     patience=10,
                                                     min_lr=1e-6,
                                                     verbose=True)

    start_ep = 0
    best_val_acc = 0.0
    best_val_epoch = 0

    for epoch in pbar(range(start_ep, opt.niter + 1)):
        # average meter for train/val loss and train/val acc
        metrics = dict(train_loss=util.AverageMeter(),
                       val_loss=util.AverageMeter(),
                       train_acc=util.AverageMeter(),
                       val_acc=util.AverageMeter())

        # train loop
        for step, item in enumerate(pbar(train_loader)):
            im = item[0].cuda()
            label = item[1].cuda()

            net.zero_grad()
            output = net(im)
            loss = criterion(output, label)

            accs, _ = accuracy(output, label, topk=(1, ))
            metrics['train_loss'].update(loss, n=len(label))
            metrics['train_acc'].update(accs[0], n=len(label))
            loss.backward()
            optimizer.step()
        pbar.print("%s: %0.2f" % ('train loss', metrics['train_loss'].avg))
        pbar.print("%s: %0.2f" % ('train acc', metrics['train_acc'].avg))

        # val loop
        net = net.eval()
        with torch.no_grad():
            for step, item in enumerate(pbar(val_loader)):
                im = item[0].cuda()
                label = item[1].cuda()
                output = net(im)
                loss = criterion(output, label)
                accs, _ = accuracy(output, label, topk=(1, ))
                metrics['val_loss'].update(loss, n=len(label))
                metrics['val_acc'].update(accs[0], n=len(label))
        net = net.train()

        # update scheduler
        scheduler.step(metrics['val_acc'].avg)

        # send losses to tensorboard
        for k, v in metrics.items():
            pbar.print("Metrics at end of epoch")
            pbar.print("%s: %0.4f" % (k, v.avg))
            writer.add_scalar(k.replace('_', '/'), v.avg, epoch)
        pbar.print("Learning rate: %0.6f" % optimizer.param_groups[0]['lr'])

        # do checkpoint as latest
        util.make_checkpoint(net, optimizer, epoch, metrics['val_acc'].avg,
                             opt.outf, 'latest')

        if metrics['val_acc'].avg > best_val_acc:
            pbar.print("Updating best checkpoint at epoch %d" % epoch)
            pbar.print("Old Best Epoch %d Best Val %0.2f" %
                       (best_val_epoch, best_val_acc))
            # do checkpoint as best
            util.make_checkpoint(net, optimizer, epoch, metrics['val_acc'].avg,
                                 opt.outf, 'best')
            best_val_acc = metrics['val_acc'].avg
            best_val_epoch = epoch
            pbar.print("New Best Epoch %d Best Val %0.2f" %
                       (best_val_epoch, best_val_acc))
            with open("%s/best_val.txt" % opt.outf, "w") as f:
                f.write("Best Epoch %d Best Val %0.2f\n" %
                        (best_val_epoch, best_val_acc))

        # terminate training if reached min LR and best validation is
        # not improving
        if (float(optimizer.param_groups[0]['lr']) <= 1e-6
                and epoch >= best_val_epoch + 20):
            pbar.print("Exiting training")
            pbar.print("Best Val epoch %d" % best_val_epoch)
            pbar.print("Curr epoch %d" % epoch)
            break
Example #13
def run_nn(cfg,
           mode,
           model,
           loader,
           criterion=None,
           optimizer=None,
           scheduler=None,
           apex=None,
           epoch=None):

    if mode in ['train']:
        model.train()
    elif mode in ['valid']:
        model.eval()
    else:
        raise ValueError(f'unexpected mode: {mode}')

    losses = util.AverageMeter()
    scores = util.AverageMeter()

    ids_all = []
    targets_all = []
    outputs_all = []

    # log.info(f'len(loader): {len(loader)}')
    for i, (inputs, targets, regrs) in enumerate(loader):
        # log.info(f'i: {i}')
        # zero out gradients so we can accumulate new ones over batches
        if mode in ['train']:
            optimizer.zero_grad()

        # move data to device
        inputs = inputs.to(device, dtype=torch.float)
        targets = targets.to(device, dtype=torch.float)
        regrs = regrs.to(device, dtype=torch.float)

        # log.info(f'inputs.shape: {inputs.shape}')
        # log.info(f'targets.shape: {targets.shape}')
        # log.info(f'regrs.shape: {regrs.shape}')

        outputs = model(inputs)
        # log.info(f'outputs.shape: {outputs.shape}')

        # both train mode and valid mode
        if mode in ['train', 'valid']:
            with torch.set_grad_enabled(mode == 'train'):
                loss = criterion(outputs, targets, regrs)
                # loss = criterion(torch.sigmoid(outputs), targets)

                loss = loss / cfg.n_grad_acc
                with torch.no_grad():
                    losses.update(loss.item())

        # train mode
        if mode in ['train']:
            if apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # accumulate loss

            if (i + 1) % cfg.n_grad_acc == 0 or (i + 1) == len(loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()  # update
                optimizer.zero_grad()  # flush

        # # compute metrics
        # score = metrics.dice_score(outputs, targets)
        # with torch.no_grad():
        #     scores.update(score.item())

    result = {
        'loss': losses.avg,
        'score': scores.avg,
        'ids': ids_all,
        'targets': np.array(targets_all),
        'outputs': np.array(outputs_all),
    }

    return result
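For completeness, a hedged usage sketch of `run_nn` from Example #13; `cfg`, `model`, `train_loader`, `valid_loader`, `criterion`, and `optimizer` are hypothetical stand-ins for objects that the project builds elsewhere (and `cfg` would need `n_grad_acc` plus, here, an `n_epochs` field).

# Hypothetical driver loop; names and construction are illustrative only.
for epoch in range(cfg.n_epochs):
    train_result = run_nn(cfg, 'train', model, train_loader,
                          criterion=criterion, optimizer=optimizer, epoch=epoch)
    with torch.no_grad():  # gradients are not needed for validation
        valid_result = run_nn(cfg, 'valid', model, valid_loader,
                              criterion=criterion, epoch=epoch)
    print(f"epoch {epoch}: train loss {train_result['loss']:.4f}, "
          f"valid loss {valid_result['loss']:.4f}")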