Example #1
    # (truncated excerpt) tail of the test() function, which returns the encodings and labels
    return encoded_list, label_list


# Train model
if args.model != 'PCA':
    log = []
    t_total = time.time()
    for epoch in range(args.epochs):
        log_epoch = train(epoch)
        log.append(log_epoch)

    print("Optimization Finished!")
    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

# Test
encoded_list, label_list = test()

if args.encoded is not None:
    np.savez(args.encoded, encoded=encoded_list, label=label_list)


# Save File
if args.log is not None:
    np.savetxt(args.log, log)

if args.o is not None:
    torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
               }, args.o)
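For reference, a minimal sketch of how the artifacts written above could be loaded back later. The file names encoded.npz and model.pt are placeholders for whatever paths were passed as args.encoded and args.o, and the model / optimizer objects are assumed to be constructed the same way as during training:

import numpy as np
import torch

# load the saved encodings and labels (key names match the np.savez call above)
data = np.load('encoded.npz')
encoded_list, label_list = data['encoded'], data['label']

# restore the model and optimizer state from the checkpoint
checkpoint = torch.load('model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])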
Example #2
def train():
    from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, real_image_loader, image_generator, image_generator_perm
    import lpips

    from config import IM_SIZE_GAN, BATCH_SIZE_GAN, NFC, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_AE, GAN_CKECKPOINT
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, TRIAL_NAME, DATA_NAME, MULTI_GPU
    from config import FID_INTERVAL, FID_BATCH_NBR, PRETRAINED_AE_PATH
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    real_features = None
    inception = load_patched_inception_v3().cuda()
    inception.eval()

    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    saved_image_folder = saved_model_folder = None
    log_file_path = None
    if saved_image_folder is None:
        saved_image_folder, saved_model_folder = make_folders(
            SAVE_FOLDER, 'GAN_' + TRIAL_NAME)
        log_file_path = saved_image_folder + '/../gan_log.txt'
        log_file = open(log_file_path, 'w')
        log_file.close()

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_GAN,
                                 rand_crop=True)
    print('the dataset contains %d images.' % len(dataset))
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_GAN,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    from datasets import ImageFolder
    from datasets import trans_maker_augment as trans_maker

    dataset_rgb = ImageFolder(data_root_colorful, trans_maker(512))
    dataset_skt = ImageFolder(data_root_sketch_3, trans_maker(512))

    net_ae = AE(nfc=NFC, nbr_cls=NBR_CLS)

    if PRETRAINED_AE_PATH is None:
        PRETRAINED_AE_PATH = 'train_results/' + 'AE_' + TRIAL_NAME + '/models/%d.pth' % ITERATION_AE
    else:
        from config import PRETRAINED_AE_ITER
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER

    net_ae.load_state_dicts(PRETRAINED_AE_PATH)
    net_ae.cuda()
    net_ae.eval()

    RefineGenerator = None
    if DATA_NAME == 'celeba':
        from models import RefineGenerator_face as RefineGenerator
    elif DATA_NAME == 'art' or DATA_NAME == 'shoe':
        from models import RefineGenerator_art as RefineGenerator
    net_ig = RefineGenerator(nfc=NFC, im_size=IM_SIZE_GAN).cuda()
    net_id = Discriminator(nc=3).cuda()  # we use the patch_gan, so the im_size for D should be 512 even if the training image size is 1024

    if MULTI_GPU:
        net_ae = nn.DataParallel(net_ae)
        net_ig = nn.DataParallel(net_ig)
        net_id = nn.DataParallel(net_id)

    net_ig_ema = copy_G_params(net_ig)

    opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.5, 0.999))

    if GAN_CKECKPOINT is not None:
        ckpt = torch.load(GAN_CKECKPOINT)
        net_ig.load_state_dict(ckpt['ig'])
        net_id.load_state_dict(ckpt['id'])
        net_ig_ema = ckpt['ig_ema']
        opt_ig.load_state_dict(ckpt['opt_ig'])
        opt_id.load_state_dict(ckpt['opt_id'])

    ## running-average meters for the losses we log below
    losses_g_img = AverageMeter()
    losses_d_img = AverageMeter()
    losses_mse = AverageMeter()
    losses_rec_s = AverageMeter()

    losses_rec_ae = AverageMeter()

    fixed_skt = fixed_rgb = fixed_perm = None

    fid = [[0, 0]]

    for epoch in range(EPOCH_GAN):
        for iteration in tqdm(range(10000)):
            rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

            rgb_img = rgb_img.cuda()

            rd = random.randint(0, 3)  # randint is inclusive at both ends, so rd in {2, 3} both fall through to skt_img_3
            if rd == 0:
                skt_img = skt_img_1.cuda()
            elif rd == 1:
                skt_img = skt_img_2.cuda()
            else:
                skt_img = skt_img_3.cuda()

            if iteration == 0:
                fixed_skt = skt_img_3[:8].clone().cuda()
                fixed_rgb = rgb_img[:8].clone()
                fixed_perm = true_randperm(fixed_rgb.shape[0], 'cuda')

            ### 1. train D
            gimg_ae, style_feats = net_ae(skt_img, rgb_img)
            g_image = net_ig(gimg_ae, style_feats)

            pred_r = net_id(rgb_img)
            pred_f = net_id(g_image.detach())

            loss_d = d_hinge_loss(pred_r, pred_f)

            net_id.zero_grad()
            loss_d.backward()
            opt_id.step()

            loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(
                gimg_ae, rgb_img)
            losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)

            ### 2. train G
            pred_g = net_id(g_image)
            loss_g = g_hinge_loss(pred_g)

            if DATA_NAME == 'shoe':
                loss_mse = 10 * (F.l1_loss(g_image, rgb_img) +
                                 F.mse_loss(g_image, rgb_img))
            else:
                loss_mse = 10 * percept(
                    F.adaptive_avg_pool2d(g_image, output_size=256),
                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
            losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)

            loss_all = loss_g + loss_mse

            if DATA_NAME == 'shoe':
                ### the grey image reconstruction
                perm = true_randperm(BATCH_SIZE_GAN)
                img_ae_perm, style_feats_perm = net_ae(skt_img, rgb_img[perm])

                gimg_grey = net_ig(img_ae_perm, style_feats_perm)
                gimg_grey = gimg_grey.mean(dim=1, keepdim=True)
                real_grey = rgb_img.mean(dim=1, keepdim=True)
                loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
                loss_all += 10 * loss_rec_grey

            net_ig.zero_grad()
            loss_all.backward()
            opt_ig.step()

            for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            ### 3. logging
            losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
            losses_d_img.update(pred_r.mean().item(), BATCH_SIZE_GAN)

            if iteration % SAVE_IMAGE_INTERVAL == 0:  # save the current preview images
                with torch.no_grad():

                    backup_para_g = copy_G_params(net_ig)
                    load_params(net_ig, net_ig_ema)

                    gimg_ae, style_feats = net_ae(fixed_skt, fixed_rgb)
                    gmatch = net_ig(gimg_ae, style_feats)

                    gimg_ae_perm, style_feats = net_ae(fixed_skt,
                                                       fixed_rgb[fixed_perm])
                    gmismatch = net_ig(gimg_ae_perm, style_feats)

                    gimg = torch.cat([
                        F.interpolate(fixed_rgb, IM_SIZE_GAN),
                        F.interpolate(fixed_skt.repeat(1, 3, 1, 1),
                                      IM_SIZE_GAN), gmatch,
                        F.interpolate(gimg_ae, IM_SIZE_GAN), gmismatch,
                        F.interpolate(gimg_ae_perm, IM_SIZE_GAN)
                    ])

                    vutils.save_image(
                        gimg,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}.jpg',
                        normalize=True,
                        range=(-1, 1))
                    del gimg

                    make_matrix(
                        dataset_rgb, dataset_skt, net_ae, net_ig, 5,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}_matrix.jpg'
                    )

                    load_params(net_ig, backup_para_g)

            if iteration % LOG_INTERVAL == 0:
                log_msg = 'Iter: [{0}/{1}] G: {losses_g_img.avg:.4f}  D: {losses_d_img.avg:.4f}  MSE: {losses_mse.avg:.4f}  Rec: {losses_rec_s.avg:.5f}  FID: {fid:.4f}'.format(
                    epoch,
                    iteration,
                    losses_g_img=losses_g_img,
                    losses_d_img=losses_d_img,
                    losses_mse=losses_mse,
                    losses_rec_s=losses_rec_s,
                    fid=fid[-1][0])

                print(log_msg)
                print('%.5f' % (losses_rec_ae.avg))

                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_file.write(log_msg + '\n')
                    log_file.close()

                losses_g_img.reset()
                losses_d_img.reset()
                losses_mse.reset()
                losses_rec_s.reset()
                losses_rec_ae.reset()

            if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000:
                print('Saving history model')
                torch.save(
                    {
                        'ig': net_ig.state_dict(),
                        'id': net_id.state_dict(),
                        'ae': net_ae.state_dict(),
                        'ig_ema': net_ig_ema,
                        'opt_ig': opt_ig.state_dict(),
                        'opt_id': opt_id.state_dict(),
                    }, '%s/%d.pth' % (saved_model_folder, epoch))

            if iteration % FID_INTERVAL == 0 and iteration > 1:
                print("calculating FID ...")
                fid_batch_images = FID_BATCH_NBR
                if real_features is None:
                    if os.path.exists('%s_fid_feats.npy' % (DATA_NAME)):
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))
                    else:
                        real_features = extract_feature_from_generator_fn(
                            real_image_loader(dataloader,
                                              n_batches=fid_batch_images),
                            inception)
                        real_mean = np.mean(real_features, 0)
                        real_cov = np.cov(real_features, rowvar=False)
                        pickle.dump(
                            {
                                'feats': real_features,
                                'mean': real_mean,
                                'cov': real_cov
                            }, open('%s_fid_feats.npy' % (DATA_NAME), 'wb'))
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))

                sample_features = extract_feature_from_generator_fn(
                    image_generator(dataset,
                                    net_ae,
                                    net_ig,
                                    n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid = calc_fid(sample_features,
                                   real_mean=real_features['mean'],
                                   real_cov=real_features['cov'])
                sample_features_perm = extract_feature_from_generator_fn(
                    image_generator_perm(dataset,
                                         net_ae,
                                         net_ig,
                                         n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid_perm = calc_fid(sample_features_perm,
                                        real_mean=real_features['mean'],
                                        real_cov=real_features['cov'])

                fid.append([cur_fid, cur_fid_perm])
                print('fid:', fid)
                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_msg = 'fid: %.5f, %.5f' % (fid[-1][0], fid[-1][1])
                    log_file.write(log_msg + '\n')
                    log_file.close()
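Everything train() needs is pulled from config.py inside the function body, so a driver script (a hypothetical sketch, not part of the original snippet) reduces to:

if __name__ == '__main__':
    # all hyperparameters, paths, and intervals are read from config.py inside train()
    train()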
Example #3
class BiomeAE():
    def __init__(self, args):
        if args.model in ["BiomeAEL0"]:
            self.mlp_type = "L0"
        else:
            self.mlp_type = None

        self.model_alias = args.model_alias
        self.model = args.model
        self.snap_loc = os.path.join(args.vis_dir, "snap.pt")

        #tl.configure("runs/ds.{}".format(model_alias))
        #tl.log_value(model_alias, 0)

        """ no stat file needed for now
        stat_alias = 'obj_DataStat+%s_%s' % (args.dataset_name, args.dataset_subset)
        stat_path = os.path.join(
            output_dir, '%s.pkl' % (stat_alias)
        )
        with open(stat_path,'rb') as sf:
            data_stats = pickle.load(sf)
        """

        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
        self.predictor = None

        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    def get_transformation(self):
        return None

    def loss_fn(self, recon_x, x, mean, log_var):
        if self.model in ["BiomeAE","BiomeAESnip"]:
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x))
        if self.model in ["BiomeAEL0"]:
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x))+self.predictor.regularization()
        elif self.model =="BiomeVAE":
            BCE = torch.nn.functional.binary_cross_entropy(
                recon_x.view(-1, 28*28), x.view(-1, 28*28), reduction='sum')
            KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
            return (BCE + KLD) / x.size(0)

    def param_l0(self):
        return self.predictor.param_l0()

    def init_fit(self, X1_train, X2_train, y_train, X1_val, X2_val, y_val, args, ):
        self.train_loader = get_dataloader(X1_train, X2_train, y_train, args.batch_size)
        self.test_loader = get_dataloader(X1_val, X2_val, y_val, args.batch_size)

        self.predictor = AE(
            encoder_layer_sizes=[X1_train.shape[1]],
            latent_size=args.latent_size,
            decoder_layer_sizes=[X2_train.shape[1]],
            activation=args.activation,
            batch_norm= args.batch_norm,
            dropout=args.dropout,
            mlp_type=self.mlp_type,
            conditional=args.conditional,
            num_labels=10 if args.conditional else 0).to(self.device)
        self.optimizer = torch.optim.Adam(self.predictor.parameters(), lr=args.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.8)

    def train(self, args):
        if args.contr:
            print("Loading from ", self.snap_loc)
            loaded_model_para = torch.load(self.snap_loc)
            self.predictor.load_state_dict(loaded_model_para)

        t = 0
        logs = defaultdict(list)
        iterations_per_epoch = len(self.train_loader.dataset) / args.batch_size
        num_iterations = int(iterations_per_epoch * args.epochs)
        for epoch in range(args.epochs):

            tracker_epoch = defaultdict(lambda: defaultdict(dict))

            for iteration, (x1, x2, y) in enumerate(self.train_loader):
                t+=1

                x1, x2, y = x1.to(self.device), x2.to(self.device), y.to(self.device)

                if args.conditional:
                    x2_hat, z, mean, log_var = self.predictor(x1, y)
                else:
                    x2_hat, z, mean, log_var = self.predictor(x1)

                for i, yi in enumerate(y):
                    id = len(tracker_epoch)
                    tracker_epoch[id]['x'] = z[i, 0].item()
                    tracker_epoch[id]['y'] = z[i, 1].item()
                    tracker_epoch[id]['label'] = yi.item()

                loss = self.loss_fn(x2_hat, x2, mean, log_var)

                self.optimizer.zero_grad()
                loss.backward()
                if (t + 1) % int(num_iterations / 10) == 0:
                    self.scheduler.step()
                self.optimizer.step()

                #enforce non-negative
                if args.nonneg_weight:
                    for layer in self.predictor.modules():
                        if isinstance(layer, nn.Linear):
                            layer.weight.data.clamp_(0.0)


                logs['loss'].append(loss.item())

                if iteration % args.print_every == 0 or iteration == len(self.train_loader)-1:
                    print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                        epoch, args.epochs, iteration, len(self.train_loader)-1, loss.item()))
                    if args.model == "VAE":
                        if args.conditional:
                            c = torch.arange(0, 10).long().unsqueeze(1)
                            x = self.predictor.inference(n=c.size(0), c=c)
                        else:
                            x = self.predictor.inference(n=10)

        if not args.contr:
            print("Saving to ", self.snap_loc)
            torch.save(self.predictor.state_dict(), self.snap_loc)

    def fit(self,X1_train, X2_train, y_train, X1_val, X2_val, y_val, args,):
        self.init_fit(X1_train, X2_train, y_train, X1_val, X2_val, y_val, args)
        self.train(args)

    def get_graph(self):
        """
        return nodes and weights
        :return:
        """
        nodes = []
        weights = []
        for l, layer in enumerate(self.predictor.modules()):
            if isinstance(layer, nn.Linear):
                lin_layer = layer
                nodes.append(["%s"%(x) for x in list(range(lin_layer.in_features))])
                weights.append(lin_layer.weight.detach().cpu().numpy().T)
        nodes.append(["%s"%(x) for x in list(range(lin_layer.out_features))]) #last linear layer

        return (nodes, weights)

    def predict(self,X1_val, X2_val, y_val, args):
        #Batch test
        x1, x2, y = torch.FloatTensor(X1_val).to(self.device), torch.FloatTensor(X2_val).to(self.device), torch.FloatTensor(y_val).to(self.device)
        if args.conditional:
            x2_hat, z, mean, log_var = self.predictor(x1, y)
        else:
            x2_hat, z, mean, log_var = self.predictor(x1)
        val_loss = self.loss_fn( x2_hat, x2, mean, log_var)
        print("val_loss: {:9.4f}", val_loss.item())
        return x2_hat.detach().cpu().numpy()

    def transform(self,X1_val, X2_val, y_val, args):
        x1, x2, y = torch.FloatTensor(X1_val).to(self.device), torch.FloatTensor(X2_val).to(
            self.device), torch.FloatTensor(y_val).to(self.device)
        if args.conditional:
            x2_hat, z, mean, log_var = self.predictor(x1, y)
        else:
            x2_hat, z, mean, log_var = self.predictor(x1)

        return z.detach().cpu().numpy()

    def get_influence_matrix(self):
        return self.predictor.get_influence_matrix()
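A rough usage sketch (hypothetical, not part of the original code), assuming an args namespace carrying the fields referenced above (model, model_alias, vis_dir, gpu_num, seed, batch_size, latent_size, activation, batch_norm, dropout, conditional, learning_rate, epochs, print_every, contr, nonneg_weight) and pre-split NumPy arrays:

model = BiomeAE(args)
model.fit(X1_train, X2_train, y_train, X1_val, X2_val, y_val, args)

X2_pred = model.predict(X1_val, X2_val, y_val, args)  # reconstructions of X2
Z_val = model.transform(X1_val, X2_val, y_val, args)  # latent codes
nodes, weights = model.get_graph()                    # per-layer node lists and weight matrices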
Example #4
def data_aug(data, lr=0.001, epoch=800, batch_size=128):
    folder = 'data_aug'
    save_path = f'{folder}/data_augment.csv'
    clean_file(save_path)
    store_e = [700, 750, 800]
    if not os.path.exists(folder):
        os.makedirs(folder)
    else:
        for i in store_e:
            result = test(data, folder, i)
        return result

    train_loss_curve = []
    valid_loss_curve = []
    # load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AE()
    model = model.to(device)
    model.train()

    dataset = AEDataset(data)
    train_size = int(0.85 * len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = random_split(dataset,
                                                [train_size, valid_size])
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True)
    valid_dataloader = DataLoader(dataset=valid_dataset,
                                  batch_size=batch_size,
                                  shuffle=True)

    # loss function and optimizer
    # you can swap in a different loss function and optimizer here if you want
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best = 100
    # start training
    for e in range(epoch):
        train_loss = 0.0
        valid_loss = 0.0

        print(f'\nEpoch: {e+1}/{epoch}')
        print('-' * len(f'Epoch: {e+1}/{epoch}'))
        # tqdm to display a progress bar
        for inputs in tqdm(train_dataloader):
            # data from data_loader
            inputs = inputs.float().to(device)
            outputs = model(inputs, device)
            loss = criterion(outputs, inputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # validation pass; gradients are not needed here
        with torch.no_grad():
            for inputs in tqdm(valid_dataloader):
                # data from data_loader
                inputs = inputs.float().to(device)
                outputs = model(inputs, device)
                # MSE loss
                loss = criterion(outputs, inputs)
                # accumulate validation loss
                valid_loss += loss.item()
        # average losses for this epoch
        loss_epoch = train_loss / len(train_dataset)
        valid_loss_epoch = valid_loss / len(valid_dataset)
        # optionally keep only the best model weights as a .pth file
        # if valid_loss_epoch < best:
        #     best = valid_loss_epoch
        #     torch.save(model.state_dict(), 'data_aug.pth')
        if e in store_e:
            torch.save(model.state_dict(), f'{folder}/ep{e}data_aug.pth')
            print(f"training in epoch  {e},start augment data!!")
            result = test(data, folder, e)
        print(f'Training loss: {loss_epoch:.4f}')
        print(f'Valid loss: {valid_loss_epoch:.4f}')
        # record the losses for this epoch
        train_loss_curve.append(loss_epoch)
        valid_loss_curve.append(valid_loss_epoch)
    # generate training curve
    # visualize(train_loss_curve,valid_loss_curve, 'Data Augmentation')
    return result
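A minimal call sketch (hypothetical), assuming data is whatever AEDataset expects as input; the function either trains from scratch or, if the data_aug folder already exists, reuses the stored checkpoints:

augmented = data_aug(data, lr=0.001, epoch=800, batch_size=128)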