def evaluate(dataloader, model, criterion, accuracy, device=None):
    """Run one gradient-free pass over ``dataloader`` and report metrics.

    Args:
        dataloader: iterable of dict batches holding ``"input"`` and ``"label"``.
        model: module to evaluate; moved to ``device`` and put in eval mode.
        criterion: callable ``(outputs, labels) -> loss tensor``.
        accuracy: callable ``(outputs, labels) -> accuracy tensor``.
        device: target device; defaults to CUDA when available, else CPU.

    Returns:
        ``(loss_avg, acc_avg)`` as floats, taken from two AverageMeters that
        are updated with ``outputs.size(0)`` as the count for each batch.
    """
    if device is None:
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)
    model.eval()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    with torch.no_grad():
        for batch in tqdm(dataloader):
            x = batch["input"].to(device)
            y = batch["label"].to(device)
            preds = model(x)
            batch_loss = criterion(preds, y)
            batch_acc = accuracy(preds, y)
            count = preds.size(0)
            loss_meter.update(batch_loss.item(), count)
            acc_meter.update(batch_acc.item(), count)

    print("eval loss %0.5f acc %0.5f " % (loss_meter.avg, acc_meter.avg))
    return float(loss_meter.avg), float(acc_meter.avg)
Beispiel #2
0
def evaluate(dataloader,
             model,
             criterion,
             accuracy,
             static_augmentations=None,
             device=None,
             random_seed=123):
    """Evaluate a few-shot model episode by episode, without gradients.

    The episode geometry (``n_way``, ``n_shot``, ``n_query``) is read from
    ``dataloader.batch_sampler``. Each batch is passed to ``model`` together
    with its labels; the model returns predictions and the matching query
    labels, on which ``criterion`` and ``accuracy`` are computed.

    Args:
        dataloader: episodic loader yielding dict batches with ``"input"``,
            ``"label"`` and, when required, ``"generated"``.
        model: few-shot model; must expose a ``mixer`` attribute.
        criterion: callable ``(outputs, query_labels) -> loss tensor``.
        accuracy: callable ``(outputs, query_labels) -> accuracy tensor``.
        static_augmentations: augmentation names forwarded to the model as
            ``augmentations``; containing ``"generated"`` also makes the batch's
            ``"generated"`` tensor be passed in. ``None`` (default) means none.
        device: target device; defaults to CUDA when available, else CPU.
        random_seed: if not ``None``, passed to ``utils.make_deterministic``
            so that repeated evaluations see the same order of test data.

    Returns:
        ``(loss_avg, acc_mean, conf)``. ``conf`` is the half-width of a
        normal-approximation 95% interval on per-episode accuracy
        (``1.96 * std / sqrt(n_episodes)``). ``acc_mean`` and ``conf`` are NaN
        when the loader yields no episodes.
    """
    print("evaluating...")
    # Fresh list per call: the augmentations are handed to the model, so a
    # shared mutable default would leak any mutation into later calls.
    if static_augmentations is None:
        static_augmentations = []
    if random_seed is not None:
        utils.make_deterministic(random_seed)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    losses = AverageMeter()
    accs = []  # one accuracy per episode, for the mean and the interval
    sampler = dataloader.batch_sampler
    nway, nshot, nquery = sampler.n_way, sampler.n_shot, sampler.n_query
    with torch.no_grad():
        for i, data in enumerate(tqdm(dataloader)):
            inputs = data["input"].to(device)
            labels = data["label"].to(device)
            inputs_generated = None
            if model.mixer is not None or "generated" in static_augmentations:
                inputs_generated = data["generated"].to(device)
            outputs, query_labels = model(inputs,
                                          labels,
                                          nway,
                                          nshot,
                                          nquery,
                                          inputs_generated=inputs_generated,
                                          # True only for the first episode
                                          print_final_nshot=(i == 0),
                                          augmentations=static_augmentations)
            loss = criterion(outputs, query_labels)
            acc = accuracy(outputs, query_labels)

            losses.update(loss.item(), outputs.size(0))
            accs.append(acc.item())

    print("eval loss: %0.5f " % losses.avg)
    if accs:
        acc = float(np.mean(accs))
        conf = float(1.96 * np.std(accs) / np.sqrt(len(accs)))
    else:
        # np.mean/np.std of an empty list are NaN anyway (with warnings)
        acc = conf = float("nan")
    print("eval acc :%0.5f +- %0.5f" % (acc, conf))
    return float(losses.avg), acc, conf
Beispiel #3
0
def main(args):
    """Fine-tune a pretrained BigGAN generator on a dataset, iteration-based.

    Sets up the device, checkpoint directory, dataloader, model, optimizer,
    LR scheduler and ``AdaBIGGANLoss`` from ``args``. For every batch, it
    looks up the embedding for each sample's index (``data[1]``), adds small
    Gaussian noise, and generates images. It computes the loss against the
    real images and takes an optimizer step. Periodically it logs a smoothed
    loss, writes ``<iteration>_recon.jpg`` samples and saves checkpoints.
    Finally it writes ``train-log.json`` into the checkpoint directory.

    Args:
        args: parsed command-line namespace (GPU, save paths, dataset, model,
            learning rates, scheduler step settings, loss scales,
            ``print_freq``, ``eval_freq`` and ``iters``). ``args.githash`` is
            set as a side effect.

    Raises:
        RuntimeError: if the dataloader yields no batches (there would be
            nothing to train on and the loop would never terminate).
    """
    device = gpu_setup(args.gpu)
    append_args = ["dataset", "model"]
    checkpoint_dir = savedir_setup(args.savedir,
                                   args=args,
                                   append_args=append_args,
                                   basedir=args.saveroot)

    args.githash = check_githash()
    save_args(checkpoint_dir, args)

    dataloader = setup_dataloader(
        name=args.dataset,
        batch_size=args.batch,
        num_workers=args.workers,
    )

    dataset_size = len(dataloader.dataset)
    print("number of images (dataset size): ", dataset_size)

    model = setup_model(args.model,
                        dataset_size=dataset_size,
                        resume=args.resume,
                        biggan_imagenet_pretrained_model_path=args.pretrained)
    # This has to be eval() even at training time: we want to freeze the
    # batchnorm running mean/var, while still tuning the batchnorm scale and
    # bias that BigGAN produces from its linear layer.
    model.eval()

    optimizer, scheduler = setup_optimizer(
        model,
        lr_g_linear=args.lr_g_l,
        lr_g_batch_stat=args.lr_g_batch_stat,
        lr_bsa_linear=args.lr_bsa_l,
        lr_embed=args.lr_embed,
        lr_class_cond_embed=args.lr_c_embed,
        step=args.step,
        step_facter=args.step_facter,
    )

    criterion = AdaBIGGANLoss(
        scale_per=args.loss_per,
        scale_emd=args.loss_emd,
        scale_reg=args.loss_re,
        normalize_img=args.loss_norm_img,
        normalize_per=args.loss_norm_per,
        dist_per=args.loss_dist_per,
    )

    # training-loop state
    losses = AverageMeter()
    print_freq = args.print_freq
    eval_freq = args.eval_freq
    save_freq = eval_freq
    max_iteration = args.iters
    log = {"log": []}

    iteration = 0
    epoch = 0
    # move model and loss to the target device
    model = model.to(device)
    criterion = criterion.to(device)
    while True:
        saw_batch = False
        # one pass over the dataset (one epoch)
        for data in dataloader:
            saw_batch = True
            img = data[0].to(device)
            indices = data[1].to(device)

            # embeddings (i.e. z) + noise (i.e. epsilon); see
            # https://github.com/nogu-atsu/SmallGAN/blob/f604cd17516963d8eec292f3faddd70c227b609a/gen_models/ada_generator.py#L29
            embeddings = model.embeddings(indices)
            embeddings_eps = torch.randn(embeddings.size(),
                                         device=device) * 0.01

            # forward
            img_generated = model(embeddings + embeddings_eps)
            loss = criterion(img_generated, img, embeddings,
                             model.linear.weight)
            losses.update(loss.item(), img.size(0))

            # compute gradient and take an optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Since PyTorch 1.1 the LR scheduler must step *after* the
            # optimizer; stepping first skips the first scheduled LR value.
            scheduler.step()

            if iteration % print_freq == 0:
                temp = "train loss: %0.5f " % loss.item()
                temp += "| smoothed loss %0.5f " % losses.avg
                log["log"].append({
                    "iteration": iteration,
                    "epoch": epoch,
                    "loss": losses.avg
                })
                print(iteration, temp)
                losses = AverageMeter()

            if iteration % eval_freq == 0 and iteration > 0:
                out_path = os.path.join(checkpoint_dir,
                                        "%d_recon.jpg" % iteration)
                generate_samples(model, out_path, dataloader.batch_size)

            if iteration % save_freq == 0 and iteration > 0:
                save_checkpoint(checkpoint_dir,
                                device,
                                model,
                                iteration=iteration)

            # NOTE: `>` means training steps run for iteration = 0 ..
            # max_iteration + 1 inclusive; kept as-is to preserve the
            # existing sample/checkpoint schedule.
            if iteration > max_iteration:
                break
            iteration += 1

        if iteration > max_iteration:
            break
        if not saw_batch:
            raise RuntimeError("dataloader yielded no batches; cannot train")
        epoch += 1

    log_save_path = os.path.join(checkpoint_dir, "train-log.json")
    save_json(log, log_save_path)
Beispiel #4
0
def main(args):
    """Train or evaluate a BigGAN-based generator on a small dataset.

    ``args.mode`` selects the behaviour:

    * ``'train'``: creates a checkpoint directory and runs epochs over the
      data, computing ``AdaBIGGANLoss`` for every batch. It takes one
      optimizer step per epoch and periodically logs, writes samples and
      saves checkpoints. At the end it writes ``train-log.json``.
    * ``'eval'``: runs forward passes only. With ``args.KMMD`` it accumulates
      a KMMD score between the embeddings and standard-normal samples. With
      ``args.FID`` it writes 1000 random samples via
      ``visualizers.random_eval``. It stops once ``epoch * args.data_num``
      exceeds 10,000.

    Args:
        args: parsed command-line namespace (mode, model, data, optimizer,
            loss and schedule settings). ``args.githash`` is set in train mode.

    Raises:
        RuntimeError: in train mode, if an epoch yields no batches (there is
            no loss to back-propagate).
        SystemExit: if ``args.model`` is not a supported model name.
    """
    device = gpu_setup(args.gpu)
    append_args = ["dataset", "data_num", "model"]
    if args.mode == 'train':
        checkpoint_dir = savedir_setup(args.savedir,
                                       args=args,
                                       append_args=append_args,
                                       basedir=args.saveroot)
        args.githash = check_githash()
        save_args(checkpoint_dir, args)

    dataloader = setup_dataloader(
        name=args.dataset,
        batch_size=args.batch,
        num_workers=args.workers,
        data_num=args.data_num,
    )

    dataset_size = len(dataloader.dataset)
    print("number of images (dataset size): ", dataset_size)

    model_kwargs = dict(dataset_size=dataset_size,
                        resume=args.resume,
                        biggan_imagenet_pretrained_model_path=args.pretrained)
    if args.model in ('biggan128-ada', 'biggan128-conv1x1'):
        model = setup_model(args.model, **model_kwargs)
    elif args.model in ('biggan128-conv1x1-2', 'biggan128-conv1x1-3'):
        model = setup_model(args.model,
                            per_groups=args.per_groups,
                            **model_kwargs)
    else:
        print('Error: Model not defined')
        sys.exit(1)
    # This has to be eval() even at training time: we want to freeze the
    # batchnorm running mean/var, while still tuning the batchnorm scale and
    # bias that BigGAN produces from its linear layer.
    model.eval()

    optimizer, scheduler = setup_optimizer(
        model_name=args.model,
        model=model,
        lr_g_linear=args.lr_g_l,
        lr_g_batch_stat=args.lr_g_batch_stat,
        lr_bsa_linear=args.lr_bsa_l,
        lr_embed=args.lr_embed,
        lr_class_cond_embed=args.lr_c_embed,
        step=args.step,
        step_facter=args.step_facter,
    )

    criterion = AdaBIGGANLoss(
        scale_per=args.loss_per,
        scale_emd=args.loss_emd,
        scale_reg=args.loss_re,
        normalize_img=args.loss_norm_img,
        normalize_per=args.loss_norm_per,
        dist_per=args.loss_dist_per,
    )

    # training-loop state
    losses = AverageMeter()
    eval_kmmd = AverageMeter()
    print_freq = args.print_freq
    eval_freq = args.eval_freq
    save_freq = eval_freq
    max_epoch = args.epoch
    log = {"log": []}

    # Best smoothed training loss seen so far and its epoch. Starting at inf
    # (not a magic 1000) keeps the reported minimum honest for large losses.
    min_loss = {
        'data': float('inf'),
        'epoch': 0,
    }

    epoch = 0
    # move model and loss to the target device
    model = model.to(device)
    criterion = criterion.to(device)
    is_train = args.mode == 'train'
    while True:
        loss = None
        # Only training back-propagates; other modes need no autograd graph.
        with torch.set_grad_enabled(is_train):
            # one pass over the dataset (one epoch)
            for data in dataloader:
                # re-seeded before every batch (presumably seeds torch's RNG,
                # making the noise below reproducible — confirm define_seed)
                define_seed(SEED)
                img = data[0].to(device)
                indices = data[1].to(device)

                # embeddings (i.e. z) + noise (i.e. epsilon); see
                # https://github.com/nogu-atsu/SmallGAN/blob/f604cd17516963d8eec292f3faddd70c227b609a/gen_models/ada_generator.py#L29
                embeddings = model.embeddings(indices)
                embeddings_eps = torch.randn(embeddings.size(),
                                             device=device) * 0.01
                # forward
                img_generated = model(embeddings + embeddings_eps)

                if is_train:
                    loss = criterion(img_generated, img, embeddings,
                                     model.linear.weight)
                    losses.update(loss.item(), img.size(0))
                elif args.mode == 'eval' and args.KMMD:
                    # KMMD between the embeddings and standard-normal samples
                    latent_size = embeddings.size(1)
                    true_sample = torch.randn(args.batch,
                                              latent_size,
                                              requires_grad=False).to(device)
                    kmmd = KMMD()(true_sample, embeddings)
                    eval_kmmd.update(kmmd.item(), img.size(0))

        if is_train:
            if loss is None:
                raise RuntimeError(
                    "dataloader yielded no batches; nothing to train on")
            # NOTE(review): only the *last* batch's loss is back-propagated,
            # so with several batches per epoch the earlier batches add no
            # gradient. Confirm this is intended (e.g. one batch per epoch)
            # or accumulate gradients over all batches instead.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Since PyTorch 1.1 the LR scheduler must step *after* the
            # optimizer; stepping first skips the first scheduled LR value.
            scheduler.step()

            if epoch % print_freq == 0:
                if losses.avg < min_loss['data']:
                    min_loss['data'] = losses.avg
                    min_loss['epoch'] = epoch
                temp = "train loss: %0.5f " % loss.item()
                temp += "| smoothed loss %0.5f " % losses.avg
                temp += "| min loss %0.5f (%d)" % (min_loss['data'],
                                                   min_loss['epoch'])
                log["log"].append({"epoch": epoch, "loss": losses.avg})
                print(epoch, temp)
                losses = AverageMeter()
            if epoch % eval_freq == 0:
                img_prefix = os.path.join(checkpoint_dir, "%d_" % epoch)
                generate_samples(model, img_prefix, dataloader.batch_size)

            if epoch % save_freq == 0:
                save_checkpoint(checkpoint_dir, device, model, iteration=epoch)

        elif args.mode == 'eval':
            # Evaluate on 10,000 samples (assumes each epoch contributes
            # args.data_num samples).
            if epoch * args.data_num > 10000:
                if args.KMMD:
                    print('KMMD:', eval_kmmd.avg)
                if args.FID:
                    out_path = "./outputs/" + args.resume.split('/')[2] + '/'
                    # makedirs also creates ./outputs when missing and allows
                    # re-runs; os.mkdir raised in both situations.
                    os.makedirs(out_path, exist_ok=True)
                    for i in range(1000):
                        visualizers.random_eval(model,
                                                out_path,
                                                tmp=0.3,
                                                n=1,
                                                truncate=True,
                                                roop_n=i)
                break

        if epoch > max_epoch:
            break
        epoch += 1

    if is_train:
        log_save_path = os.path.join(checkpoint_dir, "train-log.json")
        save_json(log, log_save_path)
Beispiel #5
0
def train_one_epoch(dataloader,
                    model,
                    criterion,
                    optimizer,
                    accuracy=accuracy,
                    device=None,
                    print_freq=100,
                    random_seed=None):
    """Train ``model`` for a single pass over ``dataloader``.

    If ``dataloader.batch_sampler`` exposes ``n_way``, training is episodic:
    the model is called with the episode geometry and returns predictions
    plus the matching query labels. Otherwise it is a supervised baseline
    without meta-learning: ``model.classifier(model.embed_samples(x))``
    against the batch labels.

    Args:
        dataloader: loader yielding dict batches with ``"input"``, ``"label"``
            and, when ``model.mixer`` is set, ``"generated"``.
        model: module to train; moved to ``device`` and set to train mode.
        criterion: callable ``(outputs, labels) -> loss tensor``.
        optimizer: optimizer over the model's parameters.
        accuracy: callable ``(outputs, labels) -> accuracy tensor``.
        device: target device; defaults to CUDA when available, else CPU.
        print_freq: print progress every this many batches (and on the last).
        random_seed: if not ``None``, passed to ``utils.make_deterministic``.

    Returns:
        ``(loss_avg, acc_avg)`` from the running AverageMeters, as floats.
    """
    if random_seed is not None:
        # Be careful with this. Fixing the seed on every evaluate() call is
        # fine, since the test images should come in exactly the same order.
        # For training, however, each epoch should see a different order, so
        # pass a seed that changes per call (e.g. the epoch number).
        utils.make_deterministic(random_seed)

    start_time = time.time()
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    sampler = dataloader.batch_sampler
    episodic = hasattr(sampler, "n_way")
    if episodic:
        nway, nshot, nquery = sampler.n_way, sampler.n_shot, sampler.n_query

    for step, batch in enumerate(tqdm(dataloader)):
        x = batch["input"].to(device)
        y = batch["label"].to(device)

        if not episodic:
            # supervised baseline: embed, then classify every sample
            outputs = model.classifier(model.embed_samples(x))
            query_labels = y
        else:
            if model.mixer is not None:
                generated = batch["generated"].to(device)
            else:
                generated = None
            outputs, query_labels = model(x,
                                          y,
                                          nway,
                                          nshot,
                                          nquery,
                                          inputs_generated=generated,
                                          print_final_nshot=(step == 0))

        loss = criterion(outputs, query_labels)
        acc = accuracy(outputs, query_labels)
        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # record running statistics
        count = outputs.size(0)
        loss_meter.update(loss.item(), count)
        acc_meter.update(acc.item(), count)

        if step % print_freq == 0 or step == len(dataloader) - 1:
            message = ("current loss: %0.5f " % loss.item()
                       + "acc %0.5f " % acc.item()
                       + "| running average loss %0.5f " % loss_meter.avg
                       + "acc %0.5f " % acc_meter.avg)
            print(step, message)

    minutes, seconds = divmod(time.time() - start_time, 60)
    print('this epoch took {:.0f}m {:.0f}s'.format(minutes, seconds))
    return float(loss_meter.avg), float(acc_meter.avg)