Example #1
0
  def get_inception_metrics(sample_func, eval_iter, *args, **kwargs):
    """Compute TF-based IS/FID for the generator behind ``sample_func``.

    Args:
      sample_func: Callable returning batches of generated images.
      eval_iter: Current evaluation step, used for logging/summaries.

    Returns:
      Tuple ``(IS_mean, IS_std, FID)``.
    """
    FID, IS_mean, IS_std = FID_IS(sample_func=sample_func)
    logger.info(f'\n\teval_iter {eval_iter}: '
                f'IS_mean_tf:{IS_mean:.3f} +- {IS_std:.3f}\n\tFID_tf: {FID:.3f}')
    # Skip summary writing when either metric is NaN (failed eval run).
    # Checking FID as well keeps this consistent with the sibling helper
    # that guards both values before writing summaries.
    if not math.isnan(FID) and not math.isnan(IS_mean):
      dict_data = dict(FID_tf=FID, IS_mean_tf=IS_mean, IS_std_tf=IS_std)
      summary_dict2txtfig(dict_data=dict_data, prefix='evaltf', step=eval_iter, textlogger=textlogger)

    return IS_mean, IS_std, FID
Example #2
0
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                cur_epoch):
    """Run a single training epoch and log per-iteration/epoch statistics."""
    # Reshuffle the data and apply this epoch's learning rate.
    loader.shuffle(train_loader, cur_epoch)
    lr = optim.get_epoch_lr(cur_epoch)
    optim.set_lr(optimizer, lr)
    # Switch to training mode and reset the meter/timer.
    model.train()
    train_meter.reset()
    train_meter.iter_tic()
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        # Move the mini-batch onto the current GPU.
        inputs = inputs.cuda()
        labels = labels.cuda(non_blocking=True)
        # Forward pass and loss computation.
        preds = model(inputs)
        loss = loss_fun(preds, labels)
        # Backward pass followed by a parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Top-1/top-5 errors for this mini-batch.
        top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
        # Average the stats across GPUs (no-op with a single GPU).
        loss, top1_err, top5_err = dist.scaled_all_reduce(
            [loss, top1_err, top5_err])
        # Move scalars to the CPU (this synchronizes the device).
        loss = loss.item()
        top1_err = top1_err.item()
        top5_err = top5_err.item()
        train_meter.iter_toc()
        # Record and report the iteration statistics.
        mb_size = inputs.size(0) * cfg.NUM_GPUS
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, batch_idx)
        train_meter.iter_tic()
    # Epoch-level reporting.
    train_meter.log_epoch_stats(cur_epoch)
    print(f'{cfg.OUT_DIR}')

    if not hasattr(cfg, 'search_epoch'):
        stats = train_meter.get_epoch_stats(cur_epoch)
        # Keep only scalar entries: the text logger can plot numbers only.
        numeric_stats = {
            key: val
            for key, val in stats.items() if isinstance(val, (int, float))
        }
        summary_dict2txtfig(numeric_stats,
                            prefix='train',
                            step=cur_epoch,
                            textlogger=textlogger,
                            save_fig_sec=60)
    def get_inception_metrics(sample,
                              step,
                              num_inception_images,
                              num_splits=10,
                              prints=True,
                              use_torch=True):
        """Compute Inception Score and (optionally) FID for `sample`.

        Args:
            sample: Callable yielding batches of generated images.
            step: Global step used when writing summaries.
            num_inception_images: Number of images to accumulate.
            num_splits: Number of splits for the IS estimate.
            prints: Whether to print progress messages.
            use_torch: Compute FID statistics with torch on GPU rather
                than with numpy on CPU.

        Returns:
            Tuple ``(IS_mean, IS_std, FID)``; FID is the sentinel 9999.0
            when `no_fid` is set.
        """
        if prints:
            print('Gathering activations...')
        pool, logits, labels = accumulate_inception_activations(
            sample, net, num_inception_images)
        if prints:
            print('Calculating Inception Score...')
            print(f'Num images: {len(pool)}')
        IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(),
                                                    num_splits)
        if no_fid:
            # FID computation disabled: report a sentinel value.
            FID = 9999.0
        else:
            if prints:
                print('Calculating means and covariances...')
            if use_torch:
                mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False)
            else:
                mu, sigma = np.mean(pool.cpu().numpy(),
                                    axis=0), np.cov(pool.cpu().numpy(),
                                                    rowvar=False)
            if prints:
                print('Covariances calculated, getting FID...')
            if use_torch:
                FID = torch_calculate_frechet_distance(
                    mu, sigma,
                    torch.tensor(data_mu).float().cuda(),
                    torch.tensor(data_sigma).float().cuda())
                FID = float(FID.cpu().numpy())
            else:
                FID = numpy_calculate_frechet_distance(mu, sigma, data_mu,
                                                       data_sigma)
            # BUGFIX: mu/sigma only exist on this branch. The original code
            # deleted them unconditionally after the if/else, which raised
            # NameError whenever `no_fid` was set.
            del mu, sigma
        # Delete pool, logits, and labels, just in case
        del pool, logits, labels

        if not math.isnan(FID) and not math.isnan(IS_mean):
            dict_data = dict(FID_tf=FID, IS_mean_tf=IS_mean, IS_std_tf=IS_std)
            summary_dict2txtfig(dict_data=dict_data,
                                prefix='evaltf',
                                step=step,
                                textlogger=textlogger)
        return IS_mean, IS_std, FID
Example #4
0
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Run one evaluation pass over the test set and return epoch stats."""
    # Switch to inference mode and reset the meter/timer.
    model.eval()
    test_meter.reset()
    test_meter.iter_tic()
    for batch_idx, (inputs, labels) in enumerate(test_loader):
        # Move the mini-batch onto the current GPU.
        inputs = inputs.cuda()
        labels = labels.cuda(non_blocking=True)
        # Forward pass.
        preds = model(inputs)
        # Top-1/top-5 errors for this mini-batch.
        top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
        # Average the errors across GPUs (no-op with a single GPU).
        top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
        # Move scalars to the CPU (this synchronizes the device).
        top1_err = top1_err.item()
        top5_err = top5_err.item()
        test_meter.iter_toc()
        # Record and report the iteration statistics.
        test_meter.update_stats(top1_err, top5_err,
                                inputs.size(0) * cfg.NUM_GPUS)
        test_meter.log_iter_stats(cur_epoch, batch_idx)
        test_meter.iter_tic()
    # Epoch-level reporting.
    test_meter.log_epoch_stats(cur_epoch)

    stats = test_meter.get_epoch_stats(cur_epoch)
    if not hasattr(cfg, 'search_epoch'):
        # Keep only scalar entries: the text logger can plot numbers only.
        stats = {
            key: val
            for key, val in stats.items() if isinstance(val, (int, float))
        }
        summary_dict2txtfig(stats,
                            prefix='test',
                            step=cur_epoch,
                            textlogger=textlogger,
                            save_fig_sec=60)

    return stats
Example #5
0
def main():
    """Train an SNGAN-projection GAN.

    Builds datasets/loaders, generator, discriminator and their optimizers
    from ``args``/``global_cfg``, then runs the adversarial training loop
    with periodic image previews, checkpointing, FID evaluation, and
    TF-based IS/FID summaries.
    """
    logger = logging.getLogger('tl')
    args = get_args()
    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    def _rescale(img):
        # Map [0, 1] image tensors to [-1, 1].
        return img * 2.0 - 1.0

    def _noise_adder(img):
        # Dequantize by adding small uniform noise in [0, 1/128).
        return torch.empty_like(img, dtype=img.dtype).uniform_(0.0,
                                                               1 / 128.0) + img

    # dataset
    dataset_cfg = global_cfg.get('dataset_cfg', {})
    dataset_module = dataset_cfg.get('dataset_module', 'datasets.ImageFolder')

    # SECURITY NOTE(review): eval() instantiates whatever class the config
    # names; this is only safe when the configuration file is trusted input.
    train_dataset = eval(dataset_module)(os.path.join(args.data_root),
                                         transform=transforms.Compose([
                                             transforms.ToTensor(),
                                             _rescale,
                                             _noise_adder,
                                         ]),
                                         **dataset_cfg.get(
                                             'dataset_kwargs', {}))
    train_loader = iter(
        data.DataLoader(train_dataset,
                        args.batch_size,
                        sampler=InfiniteSamplerWrapper(train_dataset),
                        num_workers=args.num_workers,
                        pin_memory=False))
    if args.calc_FID:
        eval_dataset = datasets.ImageFolder(
            os.path.join(args.data_root, 'val'),
            transforms.Compose([
                transforms.ToTensor(),
                _rescale,
            ]))
        eval_loader = iter(
            data.DataLoader(eval_dataset,
                            args.batch_size,
                            sampler=InfiniteSamplerWrapper(eval_dataset),
                            num_workers=args.num_workers,
                            pin_memory=True))
    else:
        eval_loader = None
    num_classes = len(train_dataset.classes)
    print(' prepared datasets...')
    print(' Number of training images: {}'.format(len(train_dataset)))
    # Prepare directories.
    args.num_classes = num_classes
    args, writer = prepare_results_dir(args)
    # initialize models.
    _n_cls = num_classes if args.cGAN else 0

    gen_module = getattr(
        global_cfg.generator, 'module',
        'pytorch_sngan_projection_lib.models.generators.resnet64')
    model_module = importlib.import_module(gen_module)

    gen = model_module.ResNetGenerator(
        args.gen_num_features,
        args.gen_dim_z,
        args.gen_bottom_width,
        activation=F.relu,
        num_classes=_n_cls,
        distribution=args.gen_distribution).to(device)

    if args.dis_arch_concat:
        dis = SNResNetConcatDiscriminator(args.dis_num_features, _n_cls,
                                          F.relu, args.dis_emb).to(device)
    else:
        dis = SNResNetProjectionDiscriminator(args.dis_num_features, _n_cls,
                                              F.relu).to(device)
    inception_model = inception.InceptionV3().to(
        device) if args.calc_FID else None

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)
    print(' Initialized models...\n')

    if args.args_path is not None:
        print(' Load weights...\n')
        prev_args, gen, opt_gen, dis, opt_dis = utils.resume_from_args(
            args.args_path, args.gen_ckpt_path, args.dis_ckpt_path)

    # tf FID
    tf_FID = build_GAN_metric(cfg=global_cfg.GAN_metric)

    class SampleFunc(object):
        """Callable that draws a batch of fake images from `generator`."""
        def __init__(self, generator, batch, latent, gen_distribution, device):
            self.generator = generator
            self.batch = batch
            self.latent = latent
            self.gen_distribution = gen_distribution
            self.device = device

        def __call__(self, *args, **kwargs):
            with torch.no_grad():
                self.generator.eval()
                z = utils.sample_z(self.batch, self.latent, self.device,
                                   self.gen_distribution)
                pseudo_y = utils.sample_pseudo_labels(num_classes, self.batch,
                                                      self.device)
                fake_img = self.generator(z, pseudo_y)
            return fake_img

    sample_func = SampleFunc(gen,
                             batch=args.batch_size,
                             latent=args.gen_dim_z,
                             gen_distribution=args.gen_distribution,
                             device=device)

    # Training loop
    for n_iter in tqdm.tqdm(range(1, args.max_iteration + 1)):

        if n_iter >= args.lr_decay_start:
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)

        # ==================== Beginning of 1 iteration. ====================
        _l_g = .0
        cumulative_loss_dis = .0
        for i in range(args.n_dis):
            if i == 0:
                # Generator update (once per iteration, on the first D step).
                fake, pseudo_y, _ = sample_from_gen(args, device, num_classes,
                                                    gen)
                dis_fake = dis(fake, pseudo_y)
                if args.relativistic_loss:
                    real, y = sample_from_data(args, device, train_loader)
                    dis_real = dis(real, y)
                else:
                    dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()
                if n_iter % 10 == 0 and writer is not None:
                    writer.add_scalar('gen', _l_g, n_iter)

            # Discriminator update (runs n_dis times per iteration).
            fake, pseudo_y, _ = sample_from_gen(args, device, num_classes, gen)
            real, y = sample_from_data(args, device, train_loader)

            dis_fake, dis_real = dis(fake, pseudo_y), dis(real, y)
            loss_dis = dis_criterion(dis_fake, dis_real)

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
            if n_iter % 10 == 0 and i == args.n_dis - 1 and writer is not None:
                cumulative_loss_dis /= args.n_dis
                # BUGFIX: cumulative_loss_dis is already averaged above; the
                # original divided it by n_dis again, logging sum/n_dis**2.
                writer.add_scalar('dis', cumulative_loss_dis, n_iter)
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0 or n_iter == 1:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'.
                format(n_iter, args.max_iteration, _l_g, cumulative_loss_dis))
            if not args.no_image:
                writer.add_image(
                    'fake',
                    torchvision.utils.make_grid(fake,
                                                nrow=4,
                                                normalize=True,
                                                scale_each=True))
                writer.add_image(
                    'real',
                    torchvision.utils.make_grid(real,
                                                nrow=4,
                                                normalize=True,
                                                scale_each=True))
            # Save previews
            utils.save_images(n_iter, n_iter // args.checkpoint_interval,
                              args.results_root, args.train_image_root, fake,
                              real)
        if n_iter % args.checkpoint_interval == 0:
            # Save checkpoints!
            utils.save_checkpoints(args, n_iter,
                                   n_iter // args.checkpoint_interval, gen,
                                   opt_gen, dis, opt_dis)
        if (n_iter % args.eval_interval == 0
                or n_iter == 1) and eval_loader is not None:
            # TODO (crcrpar): implement Ineption score, FID, and Geometry score
            # Once these criterion are prepared, val_loader will be used.
            fid_score = evaluation.evaluate(args, n_iter, gen, device,
                                            inception_model, eval_loader)
            tqdm.tqdm.write(
                '[Eval] iteration: {:07d}/{:07d}, FID: {:07f}'.format(
                    n_iter, args.max_iteration, fid_score))
            if writer is not None:
                writer.add_scalar("FID", fid_score, n_iter)
                # Project embedding weights if exists.
                embedding_layer = getattr(dis, 'l_y', None)
                if embedding_layer is not None:
                    writer.add_embedding(embedding_layer.weight.data,
                                         list(range(args.num_classes)),
                                         global_step=n_iter)
        if n_iter % global_cfg.eval_FID_every == 0 or n_iter == 1:
            FID_tf, IS_mean_tf, IS_std_tf = tf_FID(sample_func=sample_func)
            logger.info(
                f'IS_mean_tf:{IS_mean_tf:.3f} +- {IS_std_tf:.3f}\n\tFID_tf: {FID_tf:.3f}'
            )
            if not math.isnan(IS_mean_tf):
                summary_d = {}
                summary_d['FID_tf'] = FID_tf
                summary_d['IS_mean_tf'] = IS_mean_tf
                summary_d['IS_std_tf'] = IS_std_tf
                summary_dict2txtfig(summary_d,
                                    prefix='train',
                                    step=n_iter,
                                    textlogger=global_textlogger)
            # sample_func put the generator into eval mode; restore training.
            gen.train()
    if args.test:
        shutil.rmtree(args.results_root)
Example #6
0
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch over two-view batches and log statistics."""
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # Switch to training mode.
    model.train()

    global_itr = epoch * len(train_loader)

    end = time.time()
    for step, (img_pair, _) in enumerate(train_loader):
        global_itr += 1
        # Time spent loading this batch.
        data_time.update(time.time() - end)

        if args.gpu is not None:
            img_pair[0] = img_pair[0].cuda(args.gpu, non_blocking=True)
            img_pair[1] = img_pair[1].cuda(args.gpu, non_blocking=True)

        # Forward pass on the query/key views and loss computation.
        output, target = model(im_q=img_pair[0], im_k=img_pair[1])
        loss = criterion(output, target)

        # acc1/acc5 are (K+1)-way contrast classifier accuracy.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch = img_pair[0].size(0)
        losses.update(loss.item(), batch)
        top1.update(acc1[0], batch)
        top5.update(acc5[0], batch)

        # Gradient computation and SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall-clock time for the whole step.
        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
            if args.gpu == 0:
                summary_dict2txtfig(
                    {
                        "Loss": losses.val,
                        "top1": top1.val,
                        "top5": top5.val
                    },
                    prefix='itr',
                    step=global_itr,
                    textlogger=global_textlogger)

    if args.gpu == 0:
        summary_dict2txtfig(
            {
                "Loss": losses.avg,
                "top1": top1.avg,
                "top5": top5.avg
            },
            prefix='epoch',
            step=global_itr,
            textlogger=global_textlogger)
Example #7
0
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Train the linear classifier for one epoch on frozen features."""
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))
    # Keep the model in eval mode: under the protocol of linear
    # classification on frozen features/models it is not legitimate to
    # change any part of the pre-trained model. BatchNorm in train mode
    # may revise running mean/std (even if it receives no gradient),
    # which are part of the model parameters too.
    model.eval()

    global_itr = epoch * len(train_loader)
    end = time.time()
    for step, (images, target) in enumerate(train_loader):
        global_itr += 1
        # Time spent loading this batch.
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # Forward pass and loss computation.
        output = model(images)
        loss = criterion(output, target)

        # Track accuracy and loss for this mini-batch.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch = images.size(0)
        losses.update(loss.item(), batch)
        top1.update(acc1[0], batch)
        top5.update(acc5[0], batch)

        # Gradient computation and SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall-clock time for the whole step.
        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
            if args.gpu == 0:
                summary_dict2txtfig(
                    {
                        "Loss": losses.val,
                        "top1": top1.val,
                        "top5": top5.val
                    },
                    prefix='itr',
                    step=global_itr,
                    textlogger=global_textlogger)
            if args.tl_debug:
                break

    if args.gpu == 0:
        summary_dict2txtfig(
            {
                "Loss": losses.avg,
                "top1": top1.avg,
                "top5": top5.avg
            },
            prefix='epoch',
            step=global_itr,
            textlogger=global_textlogger)
Example #8
0
def test_for_one_epoch(model, loss, test_loader, epoch_number):
    """Evaluate `model` on `test_loader` for one epoch; return average top-1."""
    model.eval()
    loss.eval()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    tick = time.time()
    for batch_idx, (images, labels) in enumerate(test_loader):
        batch_size = images.size(0)

        # Keep the batch on the same device as the model.
        if utils.is_model_cuda(model):
            images = images.cuda()
            labels = labels.cuda()

        # Time spent loading this batch.
        data_time_meter.update(time.time() - tick)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model(images)
            loss_output = loss(outputs, labels)

        # Some loss functions return (loss, adjusted_outputs); accuracy
        # must then be computed from the adjusted outputs.
        if isinstance(loss_output, tuple):
            loss_value, outputs = loss_output
        else:
            loss_value = loss_output

        # Accumulate loss and top-1/top-5 accuracy.
        loss_meter.update(loss_value.item(), batch_size)
        top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Wall-clock time for the whole batch.
        batch_time_meter.update(time.time() - tick)
        tick = time.time()

        logging.info(
            'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
            'Time {batch_time.value:.2f} ({batch_time.average:.2f})   '
            'Data {data_time.value:.2f} ({data_time.average:.2f})   '
            'Loss {loss.value:.3f} {{{loss.average:.3f}, {loss.average_recent:.3f}}}    '
            'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}}    '
            'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}}    '
            .format(epoch=epoch_number,
                    batch=batch_idx + 1,
                    epoch_size=len(test_loader),
                    batch_time=batch_time_meter,
                    data_time=data_time_meter,
                    loss=loss_meter,
                    top1=top1_meter,
                    top5=top5_meter))
        if getattr(global_cfg, 'val_dummy', False):
            break
    # Log the overall test stats
    logging.info('Epoch: [{epoch}] -- TESTING SUMMARY\t'
                 'Time {batch_time.sum:.2f}   '
                 'Data {data_time.sum:.2f}   '
                 'Loss {loss.average:.3f}     '
                 'Top-1 {top1.average:.2f}    '
                 'Top-5 {top5.average:.2f}    '.format(
                     epoch=epoch_number,
                     batch_time=batch_time_meter,
                     data_time=data_time_meter,
                     loss=loss_meter,
                     top1=top1_meter,
                     top5=top5_meter))
    # Push the epoch summary to the text logger.
    summary_d = {
        'batch_time_meter': batch_time_meter.sum,
        'data_time_meter': data_time_meter.sum,
        'loss_meter': loss_meter.average,
        'top1_meter': top1_meter.average,
        'top5_meter': top5_meter.average,
    }
    summary_dict2txtfig(dict_data=summary_d,
                        prefix='val',
                        step=epoch_number,
                        textlogger=global_textlogger)
    return top1_meter.average