Example #1
def main():
    args = arguments()
    outdir = os.path.join(args.out, dt.now().strftime('%m%d_%H%M') + "_cgan")

    #    chainer.config.type_check = False
    chainer.config.autotune = True
    chainer.config.dtype = dtypes[args.dtype]
    chainer.print_runtime_info()
    #print('Chainer version: ', chainer.__version__)
    #print('GPU availability:', chainer.cuda.available)
    #print('cuDNN availability:', chainer.cuda.cudnn_enabled)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()

    ## dataset preparation
    train_d = Dataset(args.train,
                      args.root,
                      args.from_col,
                      args.to_col,
                      clipA=args.clipA,
                      clipB=args.clipB,
                      class_num=args.class_num,
                      crop=(args.crop_height, args.crop_width),
                      imgtype=args.imgtype,
                      random=args.random_translate,
                      grey=args.grey,
                      BtoA=args.btoa)
    test_d = Dataset(args.val,
                     args.root,
                     args.from_col,
                     args.to_col,
                     clipA=args.clipA,
                     clipB=args.clipB,
                     class_num=args.class_num,
                     crop=(args.crop_height, args.crop_width),
                     imgtype=args.imgtype,
                     random=args.random_translate,
                     grey=args.grey,
                     BtoA=args.btoa)
    args.crop_height, args.crop_width = train_d.crop
    if (len(train_d) == 0):
        print("No images found!")
        exit()

    # setup training/validation data iterators
    train_iter = chainer.iterators.SerialIterator(train_d, args.batch_size)
    test_iter = chainer.iterators.SerialIterator(test_d,
                                                 args.nvis,
                                                 shuffle=False)
    test_iter_gt = chainer.iterators.SerialIterator(
        train_d, args.nvis,
        shuffle=False)  ## same as training data; used for validation

    args.ch = len(train_d[0][0])
    args.out_ch = len(train_d[0][1])
    print("Input channels {}, Output channels {}".format(args.ch, args.out_ch))
    if (len(train_d) * len(test_d) == 0):
        print("No images found!")
        exit()

    ## Set up models
    # shared pretrained layer
    if (args.gen_pretrained_encoder and args.gen_pretrained_lr_ratio == 0):
        if "resnet" in args.gen_pretrained_encoder:
            pretrained = L.ResNet50Layers()
            print("Pretrained ResNet model loaded.")
        else:
            pretrained = L.VGG16Layers()
            print("Pretrained VGG model loaded.")
        if args.gpu >= 0:
            pretrained.to_gpu()
        enc_x = net.Encoder(args, pretrained)
    else:
        enc_x = net.Encoder(args)


#    gen = net.Generator(args)
    dec_y = net.Decoder(args)

    if args.lambda_dis > 0:
        dis = net.Discriminator(args)
        models = {'enc_x': enc_x, 'dec_y': dec_y, 'dis': dis}
    else:
        dis = L.Linear(1, 1)
        models = {'enc_x': enc_x, 'dec_y': dec_y}

    ## load learnt models
    optimiser_files = []
    if args.model_gen:
        serializers.load_npz(args.model_gen, enc_x)
        serializers.load_npz(args.model_gen.replace('enc_x', 'dec_y'), dec_y)
        print('model loaded: {}, {}'.format(
            args.model_gen, args.model_gen.replace('enc_x', 'dec_y')))
        optimiser_files.append(args.model_gen.replace('enc_x', 'opt_enc_x'))
        optimiser_files.append(args.model_gen.replace('enc_x', 'opt_dec_y'))
    if args.model_dis:
        serializers.load_npz(args.model_dis, dis)
        print('model loaded: {}'.format(args.model_dis))
        optimiser_files.append(args.model_dis.replace('dis', 'opt_dis'))

    ## send models to GPU
    if args.gpu >= 0:
        enc_x.to_gpu()
        dec_y.to_gpu()
        dis.to_gpu()

    # Setup optimisers
    def make_optimizer(model, lr, opttype='Adam', pretrained_lr_ratio=1.0):
        #        eps = 1e-5 if args.dtype==np.float16 else 1e-8
        optimizer = optim[opttype](lr)
        optimizer.setup(model)
        if args.weight_decay > 0:
            if opttype in ['Adam', 'AdaBound', 'Eve']:
                optimizer.weight_decay_rate = args.weight_decay
            else:
                if args.weight_decay_norm == 'l2':
                    optimizer.add_hook(
                        chainer.optimizer.WeightDecay(args.weight_decay))
                else:
                    optimizer.add_hook(
                        chainer.optimizer_hooks.Lasso(args.weight_decay))
        return optimizer

    opt_enc_x = make_optimizer(enc_x, args.learning_rate_gen, args.optimizer)
    opt_dec_y = make_optimizer(dec_y, args.learning_rate_gen, args.optimizer)
    opt_dis = make_optimizer(dis, args.learning_rate_dis, args.optimizer)

    optimizers = {'enc_x': opt_enc_x, 'dec_y': opt_dec_y, 'dis': opt_dis}

    ## resume optimisers from file
    if args.load_optimizer:
        for (m, e) in zip(optimiser_files, optimizers):
            if m:
                try:
                    serializers.load_npz(m, optimizers[e])
                    print('optimiser loaded: {}'.format(m))
                except:
                    print("couldn't load {}".format(m))
                    pass

    # finetuning
    if args.gen_pretrained_encoder:
        if args.gen_pretrained_lr_ratio == 0:
            enc_x.base.disable_update()
        else:
            for func_name in enc_x.encoder.base._children:
                for param in enc_x.encoder.base[func_name].params():
                    param.update_rule.hyperparam.eta *= args.gen_pretrained_lr_ratio

    # Set up trainer
    updater = Updater(
        models=(enc_x, dec_y, dis),
        iterator={'main': train_iter},
        optimizer=optimizers,
        #        converter=convert.ConcatWithAsyncTransfer(),
        params={'args': args},
        device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=outdir)

    ## save learnt results at a specified interval or at the end of training
    if args.snapinterval < 0:
        args.snapinterval = args.epoch
    snapshot_interval = (args.snapinterval, 'epoch')
    display_interval = (args.display_interval, 'iteration')

    for e in models:
        trainer.extend(extensions.snapshot_object(models[e],
                                                  e + '{.updater.epoch}.npz'),
                       trigger=snapshot_interval)
        if args.parameter_statistics:
            trainer.extend(extensions.ParameterStatistics(
                models[e]))  ## very slow
    for e in optimizers:
        trainer.extend(extensions.snapshot_object(
            optimizers[e], 'opt_' + e + '{.updater.epoch}.npz'),
                       trigger=snapshot_interval)

    ## plot NN graph
    if args.lambda_rec_l1 > 0:
        trainer.extend(
            extensions.dump_graph('dec_y/loss_L1', out_name='enc.dot'))
    elif args.lambda_rec_l2 > 0:
        trainer.extend(
            extensions.dump_graph('dec_y/loss_L2', out_name='gen.dot'))
    elif args.lambda_rec_ce > 0:
        trainer.extend(
            extensions.dump_graph('dec_y/loss_CE', out_name='gen.dot'))
    if args.lambda_dis > 0:
        trainer.extend(
            extensions.dump_graph('dis/loss_real', out_name='dis.dot'))

    ## log outputs
    log_keys = ['epoch', 'iteration', 'lr']
    log_keys_gen = ['myval/loss_L1', 'myval/loss_L2']
    log_keys_dis = []
    if args.lambda_rec_l1 > 0:
        log_keys_gen.append('dec_y/loss_L1')
    if args.lambda_rec_l2 > 0:
        log_keys_gen.append('dec_y/loss_L2')
    if args.lambda_rec_ce > 0:
        log_keys_gen.extend(['dec_y/loss_CE', 'myval/loss_CE'])
    if args.lambda_reg > 0:
        log_keys.extend(['enc_x/loss_reg'])
    if args.lambda_tv > 0:
        log_keys_gen.append('dec_y/loss_tv')
    if args.lambda_dis > 0:
        log_keys_dis.extend(
            ['dec_y/loss_dis', 'dis/loss_real', 'dis/loss_fake'])
    if args.lambda_mispair > 0:
        log_keys_dis.append('dis/loss_mispair')
    if args.dis_wgan:
        log_keys_dis.extend(['dis/loss_gp'])
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport(log_keys + log_keys_gen +
                                          log_keys_dis),
                   trigger=display_interval)
    if extensions.PlotReport.available():
        #        trainer.extend(extensions.PlotReport(['lr'], 'iteration',trigger=display_interval, file_name='lr.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_gen,
                                  'iteration',
                                  trigger=display_interval,
                                  file_name='loss_gen.png',
                                  postprocess=plot_log))
        trainer.extend(
            extensions.PlotReport(log_keys_dis,
                                  'iteration',
                                  trigger=display_interval,
                                  file_name='loss_dis.png'))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # learning rate scheduling
    trainer.extend(extensions.observe_lr(optimizer_name='enc_x'),
                   trigger=display_interval)
    if args.optimizer in ['Adam', 'AdaBound', 'Eve']:
        lr_target = 'eta'
    else:
        lr_target = 'lr'
    if args.lr_drop > 0:  ## cosine annealing
        for e in [opt_enc_x, opt_dec_y, opt_dis]:
            trainer.extend(CosineShift(lr_target,
                                       args.epoch // args.lr_drop,
                                       optimizer=e),
                           trigger=(1, 'epoch'))
    else:
        for e in [opt_enc_x, opt_dec_y, opt_dis]:
            #trainer.extend(extensions.LinearShift('eta', (1.0,0.0), (decay_start_iter,decay_end_iter), optimizer=e))
            trainer.extend(extensions.ExponentialShift('lr', 0.33,
                                                       optimizer=e),
                           trigger=(args.epoch // args.lr_drop, 'epoch'))

    # evaluation
    vis_folder = os.path.join(outdir, "vis")
    os.makedirs(vis_folder, exist_ok=True)
    if not args.vis_freq:
        args.vis_freq = max(len(train_d) // 2, 50)
    trainer.extend(VisEvaluator({
        "test": test_iter,
        "train": test_iter_gt
    }, {
        "enc_x": enc_x,
        "dec_y": dec_y
    },
                                params={
                                    'vis_out': vis_folder,
                                    'args': args
                                },
                                device=args.gpu),
                   trigger=(args.vis_freq, 'iteration'))

    # ChainerUI: removed until ChainerUI updates to be compatible with Chainer 6.0
    trainer.extend(CommandsExtension())

    # Run the training
    print("\nresults are saved under: ", outdir)
    save_args(args, outdir)
    trainer.run()
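The make_optimizer helper above switches between Adam-style decoupled weight decay and hook-based L2/L1 penalties depending on the optimizer type. A minimal self-contained sketch of that pattern follows; the toy ToyChain model is an assumption, and the optim lookup table and args from the example are not used.

import chainer
import chainer.links as L

class ToyChain(chainer.Chain):
    def __init__(self):
        super().__init__()
        with self.init_scope():
            self.fc = L.Linear(None, 10)

    def forward(self, x):
        return self.fc(x)

model = ToyChain()

# Adam carries decoupled weight decay as a hyperparameter of its own.
opt_adam = chainer.optimizers.Adam(alpha=2e-4, weight_decay_rate=1e-5)
opt_adam.setup(model)

# Plain SGD-style optimizers get weight decay attached as a gradient hook instead.
opt_sgd = chainer.optimizers.MomentumSGD(lr=0.01)
opt_sgd.setup(model)
opt_sgd.add_hook(chainer.optimizer_hooks.WeightDecay(1e-5))  # L2 penalty
# opt_sgd.add_hook(chainer.optimizer_hooks.Lasso(1e-5))      # L1 alternative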
Example #2
def do_train(rawdata, charcounts, maxlens, unique_onehotvals):
    n_batches = 2000
    mb_size = 128
    lr = 2.0e-4
    momentum = 0.5
    cnt = 0
    latent_dim = 32 #24#
    recurrent_hidden_size = 24

    epoch_len = 8
    max_veclen = 0.0
    patience = 12 * epoch_len
    patience_duration = 0

    # mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

    input_dict = {}
    input_dict['discrete'] = discrete_cols
    input_dict['continuous'] = continuous_cols

    input_dict['onehot'] = {}
    for k in onehot_cols:
        dim = int(np.ceil(np.log(len(unique_onehotvals[k])) / np.log(2.0)))
        input_dict['onehot'][k] = dim

    if len(charcounts) > 0:
        text_dim = int(np.ceil(np.log(len(charcounts)) / np.log(2.0)))
        input_dict['text'] = {t: text_dim for t in text_cols}
    else:
        text_dim = 0
        input_dict['text'] = {}

    data = Dataseq(rawdata, charcounts, input_dict, unique_onehotvals, maxlens)
    data_idx = np.arange(data.__len__())
    np.random.shuffle(data_idx)
    n_folds = 6
    fold_size = 1.0 * data.__len__() / n_folds
    folds = [data_idx[int(i * fold_size):int((i + 1) * fold_size)] for i in range(6)]

    fold_groups = {}
    fold_groups[0] = {'train': [0, 1, 2, 4], 'es': [3], 'val': [5]}
    fold_groups[1] = {'train': [0, 2, 3, 5], 'es': [1], 'val': [4]}
    fold_groups[2] = {'train': [1, 3, 4, 5], 'es': [2], 'val': [0]}
    fold_groups[3] = {'train': [0, 2, 3, 4], 'es': [5], 'val': [1]}
    fold_groups[4] = {'train': [0, 1, 3, 5], 'es': [4], 'val': [2]}
    fold_groups[5] = {'train': [1, 2, 4, 5], 'es': [0], 'val': [3]}

    for fold in range(1):

        train_idx = np.array(list(itertools.chain.from_iterable([folds[i] for i in fold_groups[fold]['train']])))
        es_idx = np.array(list(itertools.chain.from_iterable([folds[i] for i in fold_groups[fold]['es']])))
        val_idx = np.array(folds[fold_groups[fold]['val'][0]])

        train = Subset(data, train_idx)
        es = Subset(data, es_idx)
        val = Subset(data, val_idx)

        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        train_iter = torch.utils.data.DataLoader(train, batch_size=int(mb_size/1), shuffle=True, **kwargs)
        train_iter_unshuffled = torch.utils.data.DataLoader(train, batch_size=mb_size, shuffle=False, **kwargs)
        es_iter = torch.utils.data.DataLoader(es, batch_size=mb_size, shuffle=True, **kwargs)
        val_iter = torch.utils.data.DataLoader(val, batch_size=mb_size, shuffle=True, **kwargs)

        embeddings = {}
        reverse_embeddings = {}
        onehot_embedding_weights = {}
        onehot_embedding_spread = {}
        for k in onehot_cols:
            dim = input_dict['onehot'][k]
            onehot_embedding_weights[k] = net.get_embedding_weight(len(unique_onehotvals[k]), dim, use_cuda=use_cuda)
            embeddings[k] = nn.Embedding(len(unique_onehotvals[k]), dim, _weight=onehot_embedding_weights[k])
            reverse_embeddings[k] = net.EmbeddingToIndex(len(unique_onehotvals[k]), dim, _weight=onehot_embedding_weights[k])

        if text_dim > 0:
            text_embedding_weights = net.get_embedding_weight(len(charcounts) + 1, text_dim, use_cuda=use_cuda)
            text_embedding = nn.Embedding(len(charcounts) + 1, text_dim, _weight=text_embedding_weights)
            text_embeddingtoindex = net.EmbeddingToIndex(len(charcounts) + 1, text_dim, _weight=text_embedding_weights)
            for k in text_cols:
                embeddings[k] = text_embedding
                reverse_embeddings[k] = text_embeddingtoindex

        enc = net.Encoder(input_dict, dim=latent_dim, recurrent_hidden_size=recurrent_hidden_size)
        dec = net.Decoder(input_dict, maxlens, dim=latent_dim, recurrent_hidden_size=recurrent_hidden_size)

        if use_cuda:
            embeddings = {k: embeddings[k].cuda() for k in embeddings.keys()}
            enc.cuda()
            dec.cuda()


        #print(enc.parameters)
        #print(dec.parameters)


        #contrastivec = contrastive.ContrastiveLoss(margin=margin)
        logloss = contrastive.GaussianOverlap()


        #solver = optim.RMSprop([p for em in embeddings.values() for p in em.parameters()] +  [p for p in enc.parameters()] + [p for p in dec.parameters()], lr=lr)
        #solver = optim.Adam(
        #    [p for em in embeddings.values() for p in em.parameters()] + [p for p in enc.parameters()] + [p for p in
        #                                                                                                  dec.parameters()],
        #    lr=lr)

        solver = optim.RMSprop(
            [p for em in embeddings.values() for p in em.parameters()] + [p for p in enc.parameters()] + [p for p in
                                                                                                          dec.parameters()],
            lr=lr, momentum=momentum)

        Tsample = next(es_iter.__iter__())
        if use_cuda:
            Tsample = {col: Variable(tt).cuda() for col, tt in Tsample.items()}
        else:
            Tsample = {col: Variable(tt) for col, tt in Tsample.items()}

        print({col: tt[0] for col, tt in Tsample.items()})

        print('starting training')
        loss = 0.0
        loss0 = 0.0
        loss1 = 0.0
        loss2 = 0.0
        loss3 = 0.0

        logger_df = pd.DataFrame(columns=['iter', 'train_loss', 'train_veclen', 'es_veclen', 'Survived_correct', 'Survived_false'])

        for it in range(n_batches):
            # X = Variable(torch.tensor(np.array([[1,2,4], [4,1,9]]))).cuda()
            T = next(iter(train_iter))
            #for col, val in T.items():
            #    T[col] = torch.cat((val, val, val, val), 0)

            T, X, X2, mu, logvar, mu2, mu2d, mu_tm, logvar2, logvar2d = calc_mus(T, embeddings, reverse_embeddings, enc, dec)
            enc_loss, enc_loss0, enc_loss1, enc_loss2, enc_loss3 = calc_losses(T, embeddings, mu, logvar, mu2, mu2d, mu_tm, logvar2, logvar2d, logloss)

            enc_loss.backward()
            solver.step()

            enc.zero_grad()
            dec.zero_grad()
            for col in embeddings.keys():
                embeddings[col].zero_grad()

            loss += enc_loss.data.cpu().numpy()
            loss0 += enc_loss0.data.cpu().numpy()
            loss1 += enc_loss1.data.cpu().numpy()
            loss2 += enc_loss2.data.cpu().numpy()
            loss3 += enc_loss3.data.cpu().numpy()
            veclen = torch.mean(torch.pow(mu, 2))
            if it % epoch_len == 0:
                print(it, loss/epoch_len, loss0/epoch_len, loss1/epoch_len, loss2/epoch_len, loss3/epoch_len, veclen.data.cpu().numpy()) #enc_loss.data.cpu().numpy(),


                if use_cuda:
                    mu = torch.zeros(len(train), mu.size(1)).cuda()
                    logvar = torch.zeros(len(train), mu.size(1)).cuda()
                    mu2 = torch.zeros(len(train), mu.size(1)).cuda()
                    mu2d = torch.zeros(len(train), mu.size(1)).cuda()
                    mu_tm = torch.zeros((len(train),) + mu_tm.size()[1:]).cuda()
                    logvar2 = torch.zeros(len(train), mu.size(1)).cuda()
                    logvar2d = torch.zeros(len(train), mu.size(1)).cuda()
                else:
                    mu = torch.zeros(len(train), mu.size(1))
                    logvar = torch.zeros(len(train), mu.size(1))
                    mu2 = torch.zeros(len(train), mu.size(1))
                    mu2d = torch.zeros(len(train), mu.size(1))
                    mu_tm = torch.zeros((len(train),) + mu_tm.size()[1:])
                    logvar2 = torch.zeros(len(train), mu.size(1))
                    logvar2d = torch.zeros(len(train), mu.size(1))

                s = 0
                for T0 in train_iter_unshuffled:
                    e = s + T0[to_predict[0]].size(0)
                    if s == 0:
                        T = {col : torch.zeros((len(train),) + val.size()[1:], dtype=val.dtype) for col, val in T0.items()}

                    T0, blah, bblah, mu[s:e], logvar[s:e], mu2[s:e], mu2d[s:e], mu_tm[s:e], logvar2[s:e], logvar2d[s:e] = calc_mus(T0, embeddings, reverse_embeddings,  enc, dec, mode='val')
                    for col, val in T0.items():
                        T[col][s:e] = T0[col]

                    s = e

                enc_loss, enc_loss0, enc_loss1, enc_loss2, enc_loss3 = calc_losses(T, embeddings, mu, logvar, mu2, mu2d, mu_tm, logvar2, logvar2d, logloss, lookfordups=False)
                vl = torch.mean(torch.pow(mu, 2))

                print(f'train enc loss {enc_loss}')
                print(f'train veclen {vl}')
                print(f'mean train logvar {torch.mean(logvar)}')
                logger_df.loc[int(it/epoch_len), ['iter', 'train_loss', 'train_veclen']] = [it, enc_loss.data.cpu().numpy(), vl.data.cpu().numpy()]

                if use_cuda:
                    mu = torch.zeros(len(es), mu.size(1)).cuda()
                    logvar = torch.zeros(len(es), mu.size(1)).cuda()
                    mu2 = torch.zeros(len(es), mu.size(1)).cuda()
                    mu2d = torch.zeros(len(es), mu.size(1)).cuda()
                else:
                    mu = torch.zeros(len(es), mu.size(1))
                    logvar = torch.zeros(len(es), mu.size(1))
                    mu2 = torch.zeros(len(es), mu.size(1))
                    mu2d = torch.zeros(len(es), mu.size(1))

                s = 0
                targets = {}
                for T0 in es_iter:
                    e = s + T0[to_predict[0]].size(0)
                    if s == 0:
                        T = {col : torch.zeros((len(es),) + val.size()[1:], dtype=val.dtype) for col, val in T0.items()}
                        correct = {col: np.zeros((len(es),) + val.size()[1:]) for col, val in T0.items()}
                        actual = {col: np.zeros((len(es),) + val.size()[1:]) for col, val in T0.items()}

                    Xsample = {}
                    for col, tt in T0.items():
                        if use_cuda:
                            tt = Variable(tt).cuda()
                        else:
                            tt = Variable(tt)

                        if col in to_predict:
                            targets[col] = tt  # keep the ground truth before the column is masked
                        if col in embeddings.keys():
                            Xsample[col] = embeddings[col](tt)
                        else:
                            Xsample[col] = tt.float()

                    for col in to_predict:
                        Xsample[col] = 0.0 * Xsample[col]  # blank out the column to be predicted


                    mu[s:e], logvar[s:e] = enc(Xsample)

                    X2sample = dec(mu[s:e])
                    T2sample = discretize(X2sample, embeddings, maxlens)

                    mu2[s:e], _ = enc(X2sample)


                    T2 = {}
                    X2dsample = {col: (1.0 * tt).detach() for col, tt in X2sample.items()}
                    for col in continuous_cols:
                        if col in to_predict:
                            correct[col][s:e] = np.abs(X2sample[col].data.cpu().numpy().reshape(-1) - targets[
                                col].data.cpu().numpy().reshape(-1))
                            actual[col][s:e] = targets[col].data.cpu().numpy().reshape(-1)
                        else:
                            correct[col][s:e] = np.abs(X2sample[col].data.cpu().numpy().reshape(-1) - T0[
                                col].data.cpu().numpy().reshape(-1))
                            actual[col][s:e] = T0[col].data.cpu().numpy().reshape(-1)


                    for col, embedding in embeddings.items():
                        # T2[col] = reverse_embeddings[col](X2sample[col])
                        X2dsample[col] = embeddings[col](T2sample[col].detach())

                        if col in to_predict:
                            correct[col][s:e] = np.abs(T2sample[col].data.cpu().numpy() == targets[col].data.cpu().numpy())
                            actual[col][s:e] = targets[col].data.cpu().numpy().reshape(-1)
                        else:
                            correct[col][s:e] = np.abs(T2sample[col].data.cpu().numpy() == T0[col].data.cpu().numpy())
                            actual[col][s:e] = T0[col].data.cpu().numpy().reshape(-1)

                    mu2d[s:e], _ = enc(X2dsample)

                    s = e

                #enc_loss, enc_loss0, enc_loss1, enc_loss3, enc_loss3 = calc_losses(T, embeddings, mu, logvar, mu2, mu2d, mu_tm, logvar2, logloss, lookfordups=False)
                #print(f'es enc loss {enc_loss}')
                vl = torch.mean(torch.pow(mu, 2))

                print(f'es veclen {vl}')
                print(f'mean es logvar {torch.mean(logvar)}')
                logger_df.loc[int(it/epoch_len), ['es_veclen', 'Survived_correct', 'Survived_false']] = vl.data.cpu().numpy(), np.mean(correct['Survived']), np.mean(actual['Survived']==0)


                for col in continuous_cols:
                    #print(np.abs(T0[col].data.cpu().numpy().reshape(-1) - T2sample[col].data.cpu().numpy().reshape(-1)))
                    print(f'% {col} mae: {np.mean(correct[col])}')

                for col in onehot_cols:
                    print(f'% {col} correct: {np.mean(correct[col])} {np.mean(actual[col]==0)}')



                '''
                for col in continuous_cols:
                    mae = np.mean(np.abs(X[col].data.cpu().numpy() - X2[col].data.cpu().numpy()))
                    mse = np.mean(np.square(X[col].data.cpu().numpy() - X2[col].data.cpu().numpy()))
                    print(f'train mae, mse {col} {mae} {mse}')
                    mae = np.mean(np.abs(Xsample[col].data.cpu().numpy() - X2sample[col].data.cpu().numpy()))
                    mse = np.mean(np.square(Xsample[col].data.cpu().numpy() - X2sample[col].data.cpu().numpy()))
                    print(f'val mae, mse {col} {mae} {mse}')

                print({col: tt[0:2].data.cpu().numpy() for col, tt in T2sample.items()})

                if 'Survived' in onehot_cols:
                    print('% survived correct: ', np.mean(T2sample['Survived'].data.cpu().numpy()==Tsample['Survived'].data.cpu().numpy()), np.mean(Tsample['Survived'].data.cpu().numpy()==np.ones_like(Tsample['Survived'].data.cpu().numpy())))

                if 'Sex' in onehot_cols:
                    print('% sex correct: ', np.mean(T2sample['Sex'].data.cpu().numpy()==Tsample['Sex'].data.cpu().numpy()), np.mean(Tsample['Sex'].data.cpu().numpy()==np.ones_like(Tsample['Sex'].data.cpu().numpy())))

                if 'Embarked' in onehot_cols:
                    print('% Embarked correct: ', np.mean(T2sample['Embarked'].data.cpu().numpy()==Tsample['Embarked'].data.cpu().numpy()) )
                    print(onehot_embedding_weights['Embarked'])

                if 'Pclass' in onehot_cols:
                    print('% Pclass correct: ',
                          np.mean(T2sample['Pclass'].data.cpu().numpy() == Tsample['Pclass'].data.cpu().numpy()))

                if 'Cabin' in text_cols:
                    print(embeddings['Cabin'].weight[data.charindex['1']])

                
                if 'Pclass' in onehot_cols:
                    diff = torch.mean(torch.pow(embeddings['Pclass'].weight - reverse_embeddings['Pclass'].weight, 2)).data.cpu().numpy()
                    print(f'diff plcass emb and reverse_emb: {diff}')
                    print(embeddings['Pclass'].weight.data.cpu().numpy())
                '''





                loss = 0.0
                loss0 = 0.0
                loss1 = 0.0
                loss2 = 0.0
                loss3 = 0.0
                #print(T2.data.cpu()[0, 0:30].numpy())

        logger_df.to_csv('logger_'+str(fold)+'.csv', index=False)
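The fold bookkeeping at the top of do_train (shuffle the indices, cut them into six folds, then build train / early-stopping / validation Subsets from a fold group) can be reproduced in isolation. A small sketch with a dummy TensorDataset; the dataset size and batch size are illustrative only.

import itertools
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

data = TensorDataset(torch.randn(600, 8))  # stand-in for Dataseq
idx = np.arange(len(data))
np.random.shuffle(idx)

n_folds = 6
fold_size = len(data) / n_folds
folds = [idx[int(i * fold_size):int((i + 1) * fold_size)] for i in range(n_folds)]

# one fold group: four folds for training, one for early stopping, one for validation
group = {'train': [0, 1, 2, 4], 'es': [3], 'val': [5]}
train_idx = np.array(list(itertools.chain.from_iterable(folds[i] for i in group['train'])))
train = Subset(data, train_idx)
es = Subset(data, folds[group['es'][0]])
val = Subset(data, folds[group['val'][0]])

train_iter = DataLoader(train, batch_size=128, shuffle=True)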
Example #3
    iterator = chainer.iterators.MultithreadIterator(
        dataset, args.batch_size, n_threads=3, repeat=False,
        shuffle=False)  ## best performance
    #    iterator = chainer.iterators.SerialIterator(dataset, args.batch_size,repeat=False, shuffle=False)

    ## load generator models
    if "gen" in args.load_models:
        gen = net.Generator(args)
        print('Loading {:s}..'.format(args.load_models))
        serializers.load_npz(args.load_models, gen)
        if args.gpu >= 0:
            gen.to_gpu()
        xp = gen.xp
        is_AE = False
    elif "enc" in args.load_models:
        enc = net.Encoder(args)
        print('Loading {:s}..'.format(args.load_models))
        serializers.load_npz(args.load_models, enc)
        dec = net.Decoder(args)
        modelfn = args.load_models.replace('enc_x', 'dec_y')
        modelfn = modelfn.replace('enc_y', 'dec_x')
        print('Loading {:s}..'.format(modelfn))
        serializers.load_npz(modelfn, dec)
        if args.gpu >= 0:
            enc.to_gpu()
            dec.to_gpu()
        xp = enc.xp
        is_AE = True
    else:
        gen = F.identity
        xp = np
Example #4
    iterator = chainer.iterators.MultithreadIterator(
        dataset, args.batch_size, n_threads=3, repeat=False,
        shuffle=False)  ## best performance
    #    iterator = chainer.iterators.SerialIterator(dataset, args.batch_size,repeat=False, shuffle=False)

    ## load generator models
    if "gen" in args.load_models:
        gen = net.Generator(args)
        print('Loading {:s}..'.format(args.load_models))
        serializers.load_npz(args.load_models, gen)
        if args.gpu >= 0:
            gen.to_gpu()
        xp = gen.xp
        is_AE = False
    elif "enc" in args.load_models:
        args.gen_nblock = args.gen_nblock // 2  # to match ordinary cycleGAN
        enc = net.Encoder(args)
        print('Loading {:s}..'.format(args.load_models))
        serializers.load_npz(args.load_models, enc)
        dec = net.Decoder(args)
        modelfn = args.load_models.replace('enc_x', 'dec_y')
        modelfn = modelfn.replace('enc_y', 'dec_x')
        print('Loading {:s}..'.format(modelfn))
        serializers.load_npz(modelfn, dec)
        if args.gpu >= 0:
            enc.to_gpu()
            dec.to_gpu()
        xp = enc.xp
        is_AE = True
    else:
        print("Specify a learnt model.")
        exit()
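The snippets above only build the non-repeating MultithreadIterator and load the networks; the actual conversion loop is not shown. A minimal sketch of that single pass over the data, using a dummy in-memory dataset in place of the image dataset:

import numpy as np
import chainer

dataset = [np.float32(i) for i in range(10)]  # stand-in for the image dataset
iterator = chainer.iterators.MultithreadIterator(
    dataset, batch_size=4, n_threads=2, repeat=False, shuffle=False)

for batch in iterator:  # ends after one epoch because repeat=False
    x = chainer.dataset.concat_examples(batch)  # here the batch would go through gen / enc+dec
    print(x.shape)
iterator.finalize()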
Example #5
    def getenc(self):
        #dummy_ds = Dataseq(self.data, self.calendar, self.sell_prices, self.cat2val, 'd_1', 'd_1885', 'd_1913')
        #dummy_enc, dummy_dec, dummy_y, dummy_w = dummy_ds.__getitem__(0)

        return net.Encoder(self.hidden_dim, self.cat2val)
Example #6
def do_train():
    epoch_len = 8
    n_batches = 1600
    lr = 1.0e-4
    mb_size = 128
    latent_dim = 32

    ARloss = contrastive.AlwaysRight()

    train = dsets.MNIST(
        root='../data/',
        train=True,
        # transform = transforms.Compose([transforms.RandomRotation(10), transforms.ToTensor()]),
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True)

    train_data = Dataseq(train.train_data, train.train_labels,
                         np.arange(train.train_labels.size(0)))

    train_iter = torch.utils.data.DataLoader(train_data,
                                             batch_size=mb_size,
                                             shuffle=True)

    enc = net.Encoder(dim=latent_dim)
    if use_cuda:
        enc = enc.cuda()

    if use_cuda:
        mu = torch.zeros(mb_size, 3, latent_dim).cuda()
        logvar = torch.zeros(mb_size, 3, latent_dim).cuda()
    else:
        mu = torch.zeros(mb_size, 3, latent_dim)
        logvar = torch.zeros(mb_size, 3, latent_dim)

    solver = optim.RMSprop([p for p in enc.parameters()], lr=lr)

    for it in range(n_batches):
        X, idx, y = next(iter(train_iter))
        if len(set(idx)) != mb_size:
            print(len(set(idx)))
        #print(y[0:5])

        if use_cuda:
            X = Variable(X).cuda()
        else:
            X = Variable(X)

        #mu[:, 0], logvar[:, 0] = enc(T)
        #mu[:, 0] = enc.reparameterize(mu[:, 0], logvar[:, 0])
        #mu[:, 1], logvar[:, 1] = torch.cat((mu[3:, 0], mu[0:3, 0]), dim=0), torch.cat((logvar[3:, 0], logvar[0:3, 0]), dim=0)
        #mu[:, 2], logvar[:, 2] = torch.cat((mu[5:, 0], mu[0:5, 0]), dim=0), torch.cat((logvar[5:, 0], logvar[0:5, 0]), dim=0)

        mu0, logvar0 = enc(X)
        mu0a = enc.reparameterize(mu0, logvar0)
        mu0b = enc.reparameterize(mu0, logvar0)
        mu1 = torch.cat((mu0a[3:], mu0a[0:3]), dim=0)
        mu2 = torch.cat((mu0b[5:], mu0b[0:5]), dim=0)
        mu = torch.cat((mu0a.unsqueeze(1), mu1.unsqueeze(1), mu2.unsqueeze(1)),
                       1)

        if use_cuda:
            target = torch.zeros(mb_size, 3).cuda()
        else:
            target = torch.zeros(mb_size, 3)

        loss = ARloss(mu, target)
        loss += 1.0 / 4.0 * torch.mean(torch.pow(mu, 2))
        loss += 1.0 / 4.0 * torch.mean(torch.exp(logvar) - logvar)

        mu = torch.cat(
            (mu0a.unsqueeze(1), mu0b.unsqueeze(1), mu2.unsqueeze(1)), 1)
        target[:, 2] = 1
        loss += 0.5 * ARloss(mu, target)

        loss.backward()
        solver.step()

        enc.zero_grad()

        if (it + 1) % epoch_len == 0:
            print(it + 1,
                  loss.data.cpu().numpy(),
                  torch.mean(torch.pow(mu0, 2)).data.cpu().numpy())

    return enc
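Example #6 leans on enc.reparameterize(mu0, logvar0) to draw two noisy copies of the same latent code. The repository's net.Encoder is not shown here; the standard reparameterisation step it presumably wraps looks like this sketch:

import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); sampling stays differentiable w.r.t. mu and logvar
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

mu = torch.zeros(4, 32)
logvar = torch.zeros(4, 32)
z_a = reparameterize(mu, logvar)  # two calls give two different samples,
z_b = reparameterize(mu, logvar)  # mirroring mu0a / mu0b in the example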
Example #7
def train(rawdata, charcounts, maxlens, unique_onehotvals):
    mb_size = 256
    lr = 2.0e-4
    cnt = 0
    latent_dim = 32
    recurrent_hidden_size = 24

    epoch_len = 8
    max_veclen = 0.0
    patience = 12 * epoch_len
    patience_duration = 0

    # mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

    input_dict = {}
    input_dict['discrete'] = discrete_cols
    input_dict['continuous'] = continuous_cols

    input_dict['onehot'] = {}
    for k in onehot_cols:
        dim = int(np.ceil(np.log(len(unique_onehotvals[k])) / np.log(2.0)))
        input_dict['onehot'][k] = dim

    if len(charcounts) > 0:
        text_dim = int(np.ceil(np.log(len(charcounts)) / np.log(2.0)))
        input_dict['text'] = {t: text_dim for t in text_cols}
    else:
        text_dim = 0
        input_dict['text'] = {}

    data = Dataseq(rawdata, charcounts, input_dict, unique_onehotvals, maxlens)
    data_idx = np.arange(data.__len__())
    np.random.shuffle(data_idx)
    n_folds = 6
    fold_size = 1.0 * data.__len__() / n_folds
    folds = [data_idx[int(i * fold_size):int((i + 1) * fold_size)] for i in range(6)]

    fold_groups = {}
    fold_groups[0] = {'train': [0, 1, 2, 4], 'es': [3], 'val': [5]}
    fold_groups[1] = {'train': [0, 2, 3, 5], 'es': [1], 'val': [4]}
    fold_groups[2] = {'train': [1, 3, 4, 5], 'es': [2], 'val': [0]}
    fold_groups[3] = {'train': [0, 2, 3, 4], 'es': [5], 'val': [1]}
    fold_groups[4] = {'train': [0, 1, 3, 5], 'es': [4], 'val': [2]}
    fold_groups[5] = {'train': [1, 2, 4, 5], 'es': [0], 'val': [3]}

    for fold in range(1):

        train_idx = np.array(list(itertools.chain.from_iterable([folds[i] for i in fold_groups[fold]['train']])))
        es_idx = np.array(list(itertools.chain.from_iterable([folds[i] for i in fold_groups[fold]['es']])))
        val_idx = np.array(folds[fold_groups[fold]['val'][0]])

        train = Subset(data, train_idx)
        es = Subset(data, es_idx)
        val = Subset(data, val_idx)

        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        train_iter = torch.utils.data.DataLoader(train, batch_size=mb_size, shuffle=True, **kwargs)
        es_iter = torch.utils.data.DataLoader(es, batch_size=mb_size, shuffle=True, **kwargs)
        val_iter = torch.utils.data.DataLoader(val, batch_size=mb_size, shuffle=True, **kwargs)

        embeddings = {}
        reverse_embeddings = {}
        onehot_embedding_weights = {}
        for k in onehot_cols:
            dim = input_dict['onehot'][k]
            onehot_embedding_weights[k] = net.get_embedding_weight(len(unique_onehotvals[k]), dim)
            if use_cuda:
                onehot_embedding_weights[k] = onehot_embedding_weights[k].cuda()
            #embeddings[k] = nn.Embedding(len(unique_onehotvals[k]), dim, max_norm=1.0)
            embeddings[k] = nn.Embedding(len(unique_onehotvals[k]), dim, _weight=onehot_embedding_weights[k], max_norm=1.0)
            reverse_embeddings[k] = net.EmbeddingToIndex(len(unique_onehotvals[k]), dim, _weight=onehot_embedding_weights[k])

        if text_dim > 0:
            text_embedding_weights = net.get_embedding_weight(len(charcounts) + 1, text_dim)
            if use_cuda:
                text_embedding_weights = text_embedding_weights.cuda()
            #text_embedding = nn.Embedding(len(charcounts)+1, text_dim, max_norm=1.0)
            text_embedding = nn.Embedding(len(charcounts) + 1, text_dim, _weight=text_embedding_weights, max_norm=1.0)
            text_embeddingtoindex = net.EmbeddingToIndex(len(charcounts) + 1, text_dim, _weight=text_embedding_weights)
            for k in text_cols:
                embeddings[k] = text_embedding
                reverse_embeddings[k] = text_embeddingtoindex

        enc = net.Encoder(input_dict, dim=latent_dim, recurrent_hidden_size=recurrent_hidden_size)
        dec = net.Decoder(input_dict, maxlens, dim=latent_dim, recurrent_hidden_size=recurrent_hidden_size)

        if use_cuda:
            embeddings = {k: embeddings[k].cuda() for k in embeddings.keys()}
            reverse_embeddings = {k: reverse_embeddings[k].cuda() for k in embeddings.keys()}
            enc.cuda()
            dec.cuda()


        #print(enc.parameters)
        #print(dec.parameters)


        contrastivec = contrastive.ContrastiveLoss(margin=margin)


        #solver = optim.RMSprop([p for em in embeddings.values() for p in em.parameters()] +  [p for p in enc.parameters()] + [p for p in dec.parameters()], lr=lr)
        solver = optim.Adam(
            [p for em in embeddings.values() for p in em.parameters()] + [p for p in enc.parameters()] + [p for p in
                                                                                                          dec.parameters()],
            lr=lr)

        Tsample = next(es_iter.__iter__())
        if use_cuda:
            Tsample = {col: Variable(tt[0:128]).cuda() for col, tt in Tsample.items()}
        else:
            Tsample = {col: Variable(tt[0:128]) for col, tt in Tsample.items()}

        print({col: tt[0] for col, tt in Tsample.items()})

        print('starting training')
        loss = 0.0
        for it in range(1000000):
            # X = Variable(torch.tensor(np.array([[1,2,4], [4,1,9]]))).cuda()
            batch_idx, T = next(enumerate(train_iter))
            if use_cuda:
                T = {col: Variable(tt).cuda() for col, tt in T.items()}
            else:
                T = {col: Variable(tt) for col, tt in T.items()}

            X = {}
            for col, tt in T.items():
                if col in embeddings.keys():
                    X[col] = embeddings[col](tt)
                else:
                    X[col] = tt.float()

            mu = enc(X)
            X2 = dec(mu)

            T2 = {}
            X2d = {col: (1.0 * tt).detach() for col, tt in X2.items()}


            for col, embedding in embeddings.items():
                T2[col] = reverse_embeddings[col](X2[col])
                X2[col] = 0.5*X2[col] + 0.5*embeddings[col](T2[col])
                X2d[col] = embeddings[col](T2[col].detach())



            '''
            X2d = {col: (1.0*tt).detach() for col, tt in X2.items()}
            T2 = discretize(X2d, embeddings, maxlens)
            for col, embedding in embeddings.items():
                X2d[col] = embeddings[col](T2[col].detach())
            '''
            '''
            T2 = discretize(X2, embeddings, maxlens)
            X2d = {col: (1.0*tt).detach() for col, tt in X2.items()}

            for col, embedding in embeddings.items():
                X2[col] = embeddings[col](T2[col]) #+0.05 X2[col]
                X2d[col] = embeddings[col](T2[col].detach())
            '''


            mu2 = enc(X2)
            mu2 = mu2.view(mb_size, -1)

            mu2d = enc(X2d)

            mu2d = mu2d.view(mb_size, -1)


            mu = mu.view(mb_size, -1)

            are_same = are_equal({col: x[::2] for col, x in T.items()}, {col: x[1::2] for col, x in T.items()})
            #print('f same ', torch.mean(torch.mean(are_same, 1)))
            #enc_loss = contrastivec(mu2[::2], mu2[1::2], torch.zeros(int(mb_size / 2)).cuda())
            enc_loss = contrastivec(mu[::2], mu[1::2], are_same)
            #enc_loss += 0.5*contrastivec(mu2[::2], mu2[1::2], are_same)
            #enc_loss += 0.5 * contrastivec(mu[::2], mu2[1::2], are_same)
            enc_loss += 1.0*contrastivec(mu, mu2, torch.ones(mb_size).cuda())
            enc_loss += 2.0*contrastivec(mu, mu2d, torch.zeros(mb_size).cuda())
            #enc_loss += 1.0 * contrastivec(mu2d[0::2], mu2d[1::2], torch.ones(int(mb_size/2)).cuda())
            #enc_loss += 1.0 * contrastivec(mu2d[::2], mu2d[1::2], torch.ones(int(mb_size / 2)).cuda())
            #enc_loss += 0.5 * contrastivec(mu2d[::2], mu2d[1::2], torch.ones(int(mb_size/2)).cuda())

            '''
            adotb = torch.matmul(mu, mu.permute(1, 0))  # batch_size x batch_size
            adota = torch.matmul(mu.view(-1, 1, latent_dim), mu.view(-1, latent_dim, 1))  # batch_size x 1 x 1
            diffsquares = (adota.view(-1, 1).repeat(1, mb_size) + adota.view(1, -1).repeat(mb_size, 1) - 2 * adotb) / latent_dim

            # did I f**k up something here? diffsquares can apparently be less than 0....
            mdist = torch.sqrt(torch.clamp(torch.triu(diffsquares, diagonal=1),  min=0.0))
            mdist = torch.clamp(margin - mdist, min=0.0)
            number_of_pairs = mb_size * (mb_size - 1) / 2

            enc_loss = 0.5 * torch.sum(torch.triu(torch.pow(mdist, 2), diagonal=1)) / number_of_pairs

            target = torch.ones(mu.size(0), 1)
            if use_cuda:
                target.cuda()
            enc_loss += contrastivec(mu, mu2, target.cuda())

            target = torch.zeros(mu.size(0), 1)
            if use_cuda:
                target.cuda()
            enc_loss += 2.0 * contrastivec(mu, mu2d, target.cuda())
            '''


            enc_loss.backward()
            solver.step()

            enc.zero_grad()
            dec.zero_grad()
            for col in embeddings.keys():
                embeddings[col].zero_grad()

            loss += enc_loss.data.cpu().numpy()
            veclen = torch.mean(torch.pow(mu, 2))
            if it % epoch_len == 0:
                print(it, loss/epoch_len, veclen.data.cpu().numpy()) #enc_loss.data.cpu().numpy(),

                Xsample = {}
                for col, tt in Tsample.items():
                    if col in embeddings.keys():
                        Xsample[col] = embeddings[col](tt)
                    else:
                        Xsample[col] = tt.float()

                mu = enc(Xsample)
                X2sample = dec(mu)
                X2sampled = {col: tt.detach() for col, tt in X2sample.items()}
                T2sample = discretize(X2sample, embeddings, maxlens)

                mu2 = enc(X2sample)
                mu2d = enc(X2sampled)


                if 'Fare' in continuous_cols and 'Age' in continuous_cols:
                    print([np.mean(np.abs(Xsample[col].data.cpu().numpy()-X2sample[col].data.cpu().numpy())) for col in ['Fare', 'Age']])

                print({col: tt[0:2].data.cpu().numpy() for col, tt in T2sample.items()})

                if 'Survived' in onehot_cols:
                    print('% survived correct: ', np.mean(T2sample['Survived'].data.cpu().numpy()==Tsample['Survived'].data.cpu().numpy()), np.mean(Tsample['Survived'].data.cpu().numpy()==np.ones_like(Tsample['Survived'].data.cpu().numpy())))

                if 'Cabin' in text_cols:
                    print(embeddings['Cabin'].weight[data.charindex['1']])



                are_same = are_equal({col: x[::2] for col, x in Tsample.items()}, {col: x[1::2] for col, x in Tsample.items()})
                # print('f same ', torch.mean(torch.mean(are_same, 1)))
                # enc_loss = contrastivec(mu2[::2], mu2[1::2], torch.zeros(int(mb_size / 2)).cuda())
                #es_loss = contrastivec(mu[::2], mu[1::2], are_same)
                # enc_loss += 0.25*contrastivec(mu2[::2], mu2[1::2], are_same)
                # enc_loss += 0.5 * contrastivec(mu[::2], mu2[1::2], are_same)
                es_loss = 1.0 * contrastivec(mu, mu2, torch.ones(mu.size(0)).cuda())
                #es_loss += 2.0 * contrastivec(mu, mu2d, torch.zeros(mu.size(0)).cuda())

                #print('mean mu ', torch.mean(torch.pow(mu, 2)))
                print('es loss ', es_loss)

                loss = 0.0
                #print(T2.data.cpu()[0, 0:30].numpy())
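contrastive.ContrastiveLoss(margin=margin) in Example #7 is project code, but the commented-out block above hints at the usual margin formulation: pull matching pairs together, push non-matching pairs at least margin apart. A generic PyTorch sketch of such a loss (the exact interface of the repository's class is an assumption):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MarginContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, a, b, same):
        # same = 1 for pairs that should coincide, 0 for pairs to be pushed apart
        d = F.pairwise_distance(a, b)
        pull = same * d.pow(2)
        push = (1.0 - same) * torch.clamp(self.margin - d, min=0.0).pow(2)
        return 0.5 * (pull + push).mean()

contrastivec = MarginContrastiveLoss(margin=1.0)
mu, mu2 = torch.randn(8, 32), torch.randn(8, 32)
loss = contrastivec(mu, mu2, torch.ones(8))  # analogous to contrastivec(mu, mu2, torch.ones(mb_size))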
Example #8
def main():
    parser = argparse.ArgumentParser(description='VAE MNIST')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=100,
                        help='number of epochs to learn')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=100,
                        help='learning minibatch size')
    parser.add_argument('--dimz',
                        '-z',
                        type=int,
                        default=20,
                        help='dimention of encoded vector')
    parser.add_argument('--out',
                        '-o',
                        type=str,
                        default='model',
                        help='path to the output directory')
    args = parser.parse_args()

    if not os.path.exists(args.out):
        os.makedirs(args.out)

    print(args)

    enc = net.Encoder(784, args.dimz, 500)
    dec = net.Decoder(784, args.dimz, 500)

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        enc.to_gpu()
        dec.to_gpu()
    xp = np if args.gpu < 0 else cuda.cupy

    opt_enc = chainer.optimizers.Adam()
    opt_dec = chainer.optimizers.Adam()
    opt_enc.setup(enc)
    opt_dec.setup(dec)

    train, _ = chainer.datasets.get_mnist(withlabel=False)
    train_iter = chainer.iterators.SerialIterator(train,
                                                  args.batchsize,
                                                  shuffle=True)

    updater = VAEUpdater(models=(enc, dec),
                         iterators={'main': train_iter},
                         optimizers={
                             'enc': opt_enc,
                             'dec': opt_dec
                         },
                         device=args.gpu,
                         params={})
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.dump_graph('loss'))
    trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))

    log_keys = ['epoch', 'loss', 'rec_loss']
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(log_keys))
    trainer.extend(extensions.ProgressBar())

    trainer.run()
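VAEUpdater in Example #8 is not shown; all the extensions above require is that it reports loss and rec_loss. A plausible objective for this 784-dimensional MNIST encoder/decoder pair, built from standard Chainer functions, under the assumption that enc(x) returns (mu, ln_var) and dec(z) returns Bernoulli logits:

import chainer
import chainer.functions as F

def vae_loss(enc, dec, x, k=1):
    mu, ln_var = enc(x)
    batch = len(x)
    rec_loss = 0
    for _ in range(k):
        z = F.gaussian(mu, ln_var)  # reparameterised sample
        rec_loss += F.bernoulli_nll(x, dec(z)) / (k * batch)
    kl = F.gaussian_kl_divergence(mu, ln_var) / batch  # KL(q(z|x) || N(0, I))
    loss = rec_loss + kl
    chainer.report({'rec_loss': rec_loss, 'loss': loss})
    return loss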
Example #9
def main():
    args = arguments()
    out = os.path.join(args.out, dt.now().strftime('%m%d_%H%M_AE'))
    print(args)
    print(out)
    save_args(args, out)
    args.dtype = dtypes[args.dtype]
    args.dis_activation = activation[args.dis_activation]
    args.gen_activation = activation[args.gen_activation]
    args.gen_out_activation = activation[args.gen_out_activation]
    args.gen_nblock = args.gen_nblock // 2  # to match ordinary cycleGAN

    if args.imgtype=="dcm":
        from dataset_dicom import DatasetOutMem as Dataset 
    else:
        from dataset_jpg import DatasetOutMem as Dataset   


    if not chainer.cuda.available:
        print("CUDA required")

    if len(args.gpu)==1 and args.gpu[0] >= 0:
        chainer.cuda.get_device_from_id(args.gpu[0]).use()

    # Enable autotuner of cuDNN
    chainer.config.autotune = True
    chainer.config.dtype = args.dtype
    chainer.print_runtime_info()
    # Turn off type check
#    chainer.config.type_check = False
#    print('Chainer version: ', chainer.__version__)
#    print('GPU availability:', chainer.cuda.available)
#    print('cuDNN availability:', chainer.cuda.cudnn_enabled)

    ## dataset iterator
    print("Setting up data iterators...")
    train_A_dataset = Dataset(
        path=os.path.join(args.root, 'trainA'), baseA=args.HU_base, rangeA=args.HU_range, slice_range=args.slice_range, crop=(args.crop_height,args.crop_width),random=args.random_translate, forceSpacing=0, imgtype=args.imgtype, dtype=args.dtype)
    train_B_dataset = Dataset(
        path=os.path.join(args.root, 'trainB'),  baseA=args.HU_base, rangeA=args.HU_range, slice_range=args.slice_range, crop=(args.crop_height,args.crop_width), random=args.random_translate, forceSpacing=args.forceSpacing, imgtype=args.imgtype, dtype=args.dtype)
    test_A_dataset = Dataset(
        path=os.path.join(args.root, 'testA'), baseA=args.HU_base, rangeA=args.HU_range, slice_range=args.slice_range, crop=(args.crop_height,args.crop_width), random=0, forceSpacing=0, imgtype=args.imgtype, dtype=args.dtype)
    test_B_dataset = Dataset(
        path=os.path.join(args.root, 'testB'),  baseA=args.HU_base, rangeA=args.HU_range, slice_range=args.slice_range, crop=(args.crop_height,args.crop_width), random=0, forceSpacing=args.forceSpacing, imgtype=args.imgtype, dtype=args.dtype)

    args.ch = train_A_dataset.ch
    test_A_iter = chainer.iterators.SerialIterator(test_A_dataset, args.nvis_A, shuffle=False)
    test_B_iter = chainer.iterators.SerialIterator(test_B_dataset, args.nvis_B, shuffle=False)

    
    if args.batch_size > 1:
        train_A_iter = chainer.iterators.MultiprocessIterator(
            train_A_dataset, args.batch_size, n_processes=3, shuffle=not args.conditional_discriminator)
        train_B_iter = chainer.iterators.MultiprocessIterator(
            train_B_dataset, args.batch_size, n_processes=3, shuffle=not args.conditional_discriminator)
    else:
        train_A_iter = chainer.iterators.SerialIterator(
            train_A_dataset, args.batch_size, shuffle=not args.conditional_discriminator)
        train_B_iter = chainer.iterators.SerialIterator(
            train_B_dataset, args.batch_size, shuffle=not args.conditional_discriminator)

    # setup models
    enc_x = net.Encoder(args)
    enc_y = net.Encoder(args)
    dec_x = net.Decoder(args)
    dec_y = net.Decoder(args)
    dis_x = net.Discriminator(args)
    dis_y = net.Discriminator(args)
    dis_z = net.Discriminator(args)
    models = {'enc_x': enc_x, 'enc_y': enc_y, 'dec_x': dec_x, 'dec_y': dec_y, 'dis_x': dis_x, 'dis_y': dis_y, 'dis_z': dis_z}
    optimiser_files = []

    ## load learnt models
    if args.load_models:
        for e in models:
            m = args.load_models.replace('enc_x',e)
            try:
                serializers.load_npz(m, models[e])
                print('model loaded: {}'.format(m))
            except:
                print("couldn't load {}".format(m))
                pass
            optimiser_files.append(m.replace(e,'opt_'+e).replace('dis_',''))

    # select GPU
    if len(args.gpu) == 1:
        for e in models:
            models[e].to_gpu()
        print('use gpu {}, cuDNN {}'.format(args.gpu, chainer.cuda.cudnn_enabled))
    else:
        print("mandatory GPU use: currently only a single GPU can be used")
        exit()

    # Setup optimisers
    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        eps = 1e-5 if args.dtype==np.float16 else 1e-8
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, eps=eps)
        optimizer.setup(model)
        if args.weight_decay>0:
            if args.weight_decay_norm =='l2':
                optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))
            else:
                optimizer.add_hook(chainer.optimizer_hooks.Lasso(args.weight_decay))
        return optimizer

    opt_enc_x = make_optimizer(enc_x, alpha=args.learning_rate_g)
    opt_dec_x = make_optimizer(dec_x, alpha=args.learning_rate_g)
    opt_enc_y = make_optimizer(enc_y, alpha=args.learning_rate_g)
    opt_dec_y = make_optimizer(dec_y, alpha=args.learning_rate_g)
    opt_x = make_optimizer(dis_x, alpha=args.learning_rate_d)
    opt_y = make_optimizer(dis_y, alpha=args.learning_rate_d)
    opt_z = make_optimizer(dis_z, alpha=args.learning_rate_d)
    optimizers = {'opt_enc_x': opt_enc_x,'opt_dec_x': opt_dec_x,'opt_enc_y': opt_enc_y,'opt_dec_y': opt_dec_y,'opt_x': opt_x,'opt_y': opt_y,'opt_z': opt_z}
    if args.load_optimizer:
        for (m,e) in zip(optimiser_files,optimizers):
            if m:
                try:
                    serializers.load_npz(m, optimizers[e])
                    print('optimiser loaded: {}'.format(m))
                except:
                    print("couldn't load {}".format(m))
                    pass

    # Set up an updater: TODO: multi gpu updater
    print("Preparing updater...")
    updater = Updater(
        models=(enc_x,dec_x,enc_y,dec_y, dis_x, dis_y, dis_z),
        iterator={
            'main': train_A_iter,
            'train_B': train_B_iter,
        },
        optimizer=optimizers,
        converter=convert.ConcatWithAsyncTransfer(),
        device=args.gpu[0],
        params={
            'args': args
        })

    if args.snapinterval<0:
        args.snapinterval = args.lrdecay_start+args.lrdecay_period
    log_interval = (200, 'iteration')
    model_save_interval = (args.snapinterval, 'epoch')
    vis_interval = (args.vis_freq, 'iteration')
    plot_interval = (500, 'iteration')
    
    # Set up a trainer
    print("Preparing trainer...")
    trainer = training.Trainer(updater, (args.lrdecay_start + args.lrdecay_period, 'epoch'), out=out)
    for e in models:
        trainer.extend(extensions.snapshot_object(
            models[e], e+'{.updater.epoch}.npz'), trigger=model_save_interval)
    for e in optimizers:
        trainer.extend(extensions.snapshot_object(
            optimizers[e], e+'{.updater.epoch}.npz'), trigger=model_save_interval)

    log_keys = ['epoch', 'iteration']
    log_keys_cycle = ['opt_enc_x/loss_cycle', 'opt_enc_y/loss_cycle', 'opt_dec_x/loss_cycle',  'opt_dec_y/loss_cycle', 'myval/cycle_y_l1']
    log_keys_d = ['opt_x/loss_real','opt_x/loss_fake','opt_y/loss_real','opt_y/loss_fake','opt_z/loss_x','opt_z/loss_y']
    log_keys_adv = ['opt_enc_y/loss_adv','opt_dec_y/loss_adv','opt_enc_x/loss_adv','opt_dec_x/loss_adv']
    log_keys.extend([ 'opt_dec_y/loss_id'])
    log_keys.extend([ 'opt_enc_x/loss_reg','opt_enc_y/loss_reg', 'opt_dec_x/loss_air','opt_dec_y/loss_air', 'opt_dec_y/loss_tv'])
    log_keys_d.extend(['opt_x/loss_gp','opt_y/loss_gp'])

    log_keys_all = log_keys+log_keys_d+log_keys_adv+log_keys_cycle
    trainer.extend(extensions.LogReport(keys=log_keys_all, trigger=log_interval))
    trainer.extend(extensions.PrintReport(log_keys_all), trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(CommandsExtension())
    ## to dump graph, set -lix 1 --warmup 0
#    trainer.extend(extensions.dump_graph('opt_g/loss_id', out_name='gen.dot'))
#    trainer.extend(extensions.dump_graph('opt_x/loss', out_name='dis.dot'))

    if extensions.PlotReport.available():
        trainer.extend(extensions.PlotReport(log_keys[2:], 'iteration',trigger=plot_interval, file_name='loss.png'))
        trainer.extend(extensions.PlotReport(log_keys_d, 'iteration', trigger=plot_interval, file_name='loss_d.png'))
        trainer.extend(extensions.PlotReport(log_keys_adv, 'iteration', trigger=plot_interval, file_name='loss_adv.png'))
        trainer.extend(extensions.PlotReport(log_keys_cycle, 'iteration', trigger=plot_interval, file_name='loss_cyc.png'))

    ## output filenames of training dataset
    with open(os.path.join(out, 'trainA.txt'),'w') as output:
        output.writelines("\n".join(train_A_dataset.ids))
    with open(os.path.join(out, 'trainB.txt'),'w') as output:
        output.writelines("\n".join(train_B_dataset.ids))
    # archive the scripts
    rundir = os.path.dirname(os.path.realpath(__file__))
    import zipfile
    with zipfile.ZipFile(os.path.join(out,'script.zip'), 'w', compression=zipfile.ZIP_DEFLATED) as new_zip:
        for f in ['trainAE.py','net.py','updaterAE.py','consts.py','losses.py','arguments.py','convert.py']:
            new_zip.write(os.path.join(rundir,f),arcname=f)

    ## visualisation
    vis_folder = os.path.join(out, "vis")
    if not os.path.exists(vis_folder):
        os.makedirs(vis_folder)
#    trainer.extend(visualize( (enc_x, enc_y, dec_y), vis_folder, test_A_iter, test_B_iter),trigger=(1, 'epoch'))
    trainer.extend(VisEvaluator({"main":test_A_iter, "testB":test_B_iter}, {"enc_x":enc_x, "enc_y":enc_y,"dec_x":dec_x,"dec_y":dec_y},
            params={'vis_out': vis_folder, 'single_encoder': args.single_encoder}, device=args.gpu[0]),trigger=vis_interval)

    # Run the training
    trainer.run()
Example #10
gamma = .5
train_LSTM = True
train_LSTM_using_cached_features = False
train_lstm_prob = .5
train_dis = True
image_save_interval = 200000
model_save_interval = image_save_interval
out_image_row_num = 7
out_image_col_num = 14
if train_LSTM:
    BATCH_SIZE /= seq_length
normer = args.size * args.size * 3 * 60
image_path = ['../../../trajectories/al5d']
np.random.seed(1241)
image_size = args.size
enc_model = [net.Encoder(density=8, size=image_size, latent_size=latent_size)]
gen_model = [
    net.Generator(density=8, size=image_size, latent_size=latent_size)
]
dis_model = [net.Discriminator(density=8, size=image_size)]
for i in range(num_gpus - 1):
    enc_model.append(copy.deepcopy(enc_model[0]))
    gen_model.append(copy.deepcopy(gen_model[0]))
    dis_model.append(copy.deepcopy(dis_model[0]))

enc_dis_model = net.Encoder(density=8,
                            size=image_size,
                            latent_size=latent_size)
gen_dis_model = net.Generator(density=8,
                              size=image_size,
                              latent_size=latent_size)
Example #11
0
    if args.ch != len(dataset[0][0]):
        print("number of input channels is different during training.")
    print("Input channels {}, Output channels {}".format(args.ch, args.out_ch))

    ## load generator models
    if "enc" in args.model_gen:
        if (args.gen_pretrained_encoder and args.gen_pretrained_lr_ratio == 0):
            if "resnet" in args.gen_pretrained_encoder:
                pretrained = L.ResNet50Layers()
                print("Pretrained ResNet model loaded.")
            else:
                pretrained = L.VGG16Layers()
                print("Pretrained VGG model loaded.")
            if args.gpu >= 0:
                pretrained.to_gpu()
            enc = net.Encoder(args, pretrained)
        else:
            enc = net.Encoder(args)
        print('Loading {:s}..'.format(args.model_gen))
        serializers.load_npz(args.model_gen, enc)
        dec = net.Decoder(args)
        modelfn = args.model_gen.replace('enc_x', 'dec_y')
        modelfn = modelfn.replace('enc_y', 'dec_x')
        print('Loading {:s}..'.format(modelfn))
        serializers.load_npz(modelfn, dec)
        if args.gpu >= 0:
            enc.to_gpu()
            dec.to_gpu()
        xp = enc.xp
        is_AE = True
    elif "gen" in args.model_gen:
Example #12
0
def main():
    args = arguments()
    out = os.path.join(args.out, dt.now().strftime('%m%d_%H%M'))
    print(args)
    print("\nresults are saved under: ", out)
    save_args(args, out)

    if args.imgtype == "dcm":
        from dataset_dicom import Dataset as Dataset
    else:
        from dataset_jpg import DatasetOutMem as Dataset

    # CUDA
    if not chainer.cuda.available:
        print("CUDA required")
        exit()
    if len(args.gpu) == 1 and args.gpu[0] >= 0:
        chainer.cuda.get_device_from_id(args.gpu[0]).use()
#        cuda.cupy.cuda.set_allocator(cuda.cupy.cuda.MemoryPool().malloc)

    # Enable the cuDNN autotuner
    chainer.config.autotune = True
    chainer.config.dtype = dtypes[args.dtype]
    chainer.print_runtime_info()
    # Turn off type check
    #    chainer.config.type_check = False
    #    print('Chainer version: ', chainer.__version__)
    #    print('GPU availability:', chainer.cuda.available)
    #    print('cuDNN availability:', chainer.cuda.cudnn_enabled)

    ## dataset iterator
    print("Setting up data iterators...")
    train_A_dataset = Dataset(path=os.path.join(args.root, 'trainA'),
                              args=args,
                              random=args.random_translate,
                              forceSpacing=0)
    train_B_dataset = Dataset(path=os.path.join(args.root, 'trainB'),
                              args=args,
                              random=args.random_translate,
                              forceSpacing=args.forceSpacing)
    test_A_dataset = Dataset(path=os.path.join(args.root, 'testA'),
                             args=args,
                             random=0,
                             forceSpacing=0)
    test_B_dataset = Dataset(path=os.path.join(args.root, 'testB'),
                             args=args,
                             random=0,
                             forceSpacing=args.forceSpacing)

    args.ch = train_A_dataset.ch
    args.out_ch = train_B_dataset.ch
    print("channels in A {}, channels in B {}".format(args.ch, args.out_ch))

    test_A_iter = chainer.iterators.SerialIterator(test_A_dataset,
                                                   args.nvis_A,
                                                   shuffle=False)
    test_B_iter = chainer.iterators.SerialIterator(test_B_dataset,
                                                   args.nvis_B,
                                                   shuffle=False)

    if args.batch_size > 1:
        train_A_iter = chainer.iterators.MultiprocessIterator(train_A_dataset,
                                                              args.batch_size,
                                                              n_processes=3)
        train_B_iter = chainer.iterators.MultiprocessIterator(train_B_dataset,
                                                              args.batch_size,
                                                              n_processes=3)
    else:
        train_A_iter = chainer.iterators.SerialIterator(
            train_A_dataset, args.batch_size)
        train_B_iter = chainer.iterators.SerialIterator(
            train_B_dataset, args.batch_size)

    # setup models
    enc_x = net.Encoder(args)
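    # when args.single_encoder is set, both domains share one encoder instance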
    enc_y = enc_x if args.single_encoder else net.Encoder(args)
    dec_x = net.Decoder(args)
    dec_y = net.Decoder(args)
    dis_x = net.Discriminator(args)
    dis_y = net.Discriminator(args)
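    # latent-space discriminator; a dummy 1x1 Linear link stands in when lambda_dis_z == 0
    # (presumably so the optimiser/updater bookkeeping below stays uniform)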
    dis_z = net.Discriminator(
        args) if args.lambda_dis_z > 0 else chainer.links.Linear(1, 1)
    models = {
        'enc_x': enc_x,
        'dec_x': dec_x,
        'enc_y': enc_y,
        'dec_y': dec_y,
        'dis_x': dis_x,
        'dis_y': dis_y,
        'dis_z': dis_z
    }

    ## load learnt models
    if args.load_models:
        for e in models:
            m = args.load_models.replace('enc_x', e)
            try:
                serializers.load_npz(m, models[e])
                print('model loaded: {}'.format(m))
            except Exception:
                print("couldn't load {}".format(m))

    # select GPU
    if len(args.gpu) == 1:
        for e in models:
            models[e].to_gpu()
        print('using gpu {}, cuDNN {}'.format(args.gpu,
                                              chainer.cuda.cudnn_enabled))
    else:
        print("mandatory GPU use: currently only a single GPU can be used")
        exit()

    # Setup optimisers
    def make_optimizer(model, lr, opttype='Adam'):
        #        eps = 1e-5 if args.dtype==np.float16 else 1e-8
        optimizer = optim[opttype](lr)
        #from profiled_optimizer import create_marked_profile_optimizer
        #        optimizer = create_marked_profile_optimizer(optim[opttype](lr), sync=True, sync_level=2)
        if args.weight_decay > 0:
            if opttype in ['Adam', 'AdaBound', 'Eve']:
                optimizer.weight_decay_rate = args.weight_decay
            else:
                if args.weight_decay_norm == 'l2':
                    optimizer.add_hook(
                        chainer.optimizer.WeightDecay(args.weight_decay))
                else:
                    optimizer.add_hook(
                        chainer.optimizer_hooks.Lasso(args.weight_decay))
        optimizer.setup(model)
        return optimizer

    opt_enc_x = make_optimizer(enc_x, args.learning_rate_g, args.optimizer)
    opt_dec_x = make_optimizer(dec_x, args.learning_rate_g, args.optimizer)
    opt_enc_y = make_optimizer(enc_y, args.learning_rate_g, args.optimizer)
    opt_dec_y = make_optimizer(dec_y, args.learning_rate_g, args.optimizer)
    opt_x = make_optimizer(dis_x, args.learning_rate_d, args.optimizer)
    opt_y = make_optimizer(dis_y, args.learning_rate_d, args.optimizer)
    opt_z = make_optimizer(dis_z, args.learning_rate_d, args.optimizer)
    optimizers = {
        'opt_enc_x': opt_enc_x,
        'opt_dec_x': opt_dec_x,
        'opt_enc_y': opt_enc_y,
        'opt_dec_y': opt_dec_y,
        'opt_x': opt_x,
        'opt_y': opt_y,
        'opt_z': opt_z
    }
    if args.load_optimizer:
        for e in optimizers:
            try:
                m = args.load_models.replace('enc_x', e)
                serializers.load_npz(m, optimizers[e])
                print('optimiser loaded: {}'.format(m))
            except Exception:
                print("couldn't load {}".format(m))

    # Set up an updater: TODO: multi gpu updater
    print("Preparing updater...")
    updater = Updater(
        models=(enc_x, dec_x, enc_y, dec_y, dis_x, dis_y, dis_z),
        iterator={
            'main': train_A_iter,
            'train_B': train_B_iter,
        },
        optimizer=optimizers,
        #        converter=convert.ConcatWithAsyncTransfer(),
        device=args.gpu[0],
        params={'args': args})

    if args.snapinterval < 0:
        args.snapinterval = args.lrdecay_start + args.lrdecay_period
    log_interval = (200, 'iteration')
    model_save_interval = (args.snapinterval, 'epoch')
    plot_interval = (500, 'iteration')

    # Set up a trainer
    print("Preparing trainer...")
    if args.iteration:
        stop_trigger = (args.iteration, 'iteration')
    else:
        stop_trigger = (args.lrdecay_start + args.lrdecay_period, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, out=out)
    for e in models:
        trainer.extend(extensions.snapshot_object(models[e],
                                                  e + '{.updater.epoch}.npz'),
                       trigger=model_save_interval)
#        trainer.extend(extensions.ParameterStatistics(models[e]))   ## very slow
    for e in optimizers:
        trainer.extend(extensions.snapshot_object(optimizers[e],
                                                  e + '{.updater.epoch}.npz'),
                       trigger=model_save_interval)

    log_keys = ['epoch', 'iteration', 'lr']
    log_keys_cycle = [
        'opt_enc_x/loss_cycle', 'opt_enc_y/loss_cycle', 'opt_dec_x/loss_cycle',
        'opt_dec_y/loss_cycle', 'myval/cycle_x_l1', 'myval/cycle_y_l1'
    ]
    log_keys_d = [
        'opt_x/loss_real', 'opt_x/loss_fake', 'opt_y/loss_real',
        'opt_y/loss_fake', 'opt_z/loss_x', 'opt_z/loss_y'
    ]
    log_keys_adv = [
        'opt_enc_y/loss_adv', 'opt_dec_y/loss_adv', 'opt_enc_x/loss_adv',
        'opt_dec_x/loss_adv'
    ]
    log_keys.extend(
        ['opt_enc_x/loss_reg', 'opt_enc_y/loss_reg', 'opt_dec_y/loss_tv'])
    if args.lambda_air > 0:
        log_keys.extend(['opt_dec_x/loss_air', 'opt_dec_y/loss_air'])
    if args.lambda_grad > 0:
        log_keys.extend(['opt_dec_x/loss_grad', 'opt_dec_y/loss_grad'])
    if args.lambda_identity_x > 0:
        log_keys.extend(['opt_dec_x/loss_id', 'opt_dec_y/loss_id'])
    if args.dis_reg_weighting > 0:
        log_keys_d.extend(
            ['opt_x/loss_reg', 'opt_y/loss_reg', 'opt_z/loss_reg'])
    if args.dis_wgan:
        log_keys_d.extend(['opt_x/loss_gp', 'opt_y/loss_gp', 'opt_z/loss_gp'])

    log_keys_all = log_keys + log_keys_d + log_keys_adv + log_keys_cycle
    trainer.extend(
        extensions.LogReport(keys=log_keys_all, trigger=log_interval))
    trainer.extend(extensions.PrintReport(log_keys_all), trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(extensions.observe_lr(optimizer_name='opt_enc_x'),
                   trigger=log_interval)
    # learning rate scheduling
    decay_start_iter = len(train_A_dataset) * args.lrdecay_start
    decay_end_iter = len(train_A_dataset) * (args.lrdecay_start +
                                             args.lrdecay_period)
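    # linearly anneal each optimiser's 'alpha' hyperparameter (the learning rate for Adam-type
    # optimisers) to 0 between decay_start_iter and decay_end_iter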
    for e in [opt_enc_x, opt_enc_y, opt_dec_x, opt_dec_y]:
        trainer.extend(
            extensions.LinearShift('alpha', (args.learning_rate_g, 0),
                                   (decay_start_iter, decay_end_iter),
                                   optimizer=e))
    for e in [opt_x, opt_y, opt_z]:
        trainer.extend(
            extensions.LinearShift('alpha', (args.learning_rate_d, 0),
                                   (decay_start_iter, decay_end_iter),
                                   optimizer=e))
    ## dump graph
    if args.report_start < 1:
        if args.lambda_tv > 0:
            trainer.extend(
                extensions.dump_graph('opt_dec_y/loss_tv', out_name='dec.dot'))
        if args.lambda_reg > 0:
            trainer.extend(
                extensions.dump_graph('opt_enc_x/loss_reg',
                                      out_name='enc.dot'))
        trainer.extend(
            extensions.dump_graph('opt_x/loss_fake', out_name='dis.dot'))

    # ChainerUI
    #    trainer.extend(CommandsExtension())

    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(log_keys[3:],
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_d,
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss_d.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_adv,
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss_adv.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_cycle,
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss_cyc.png'))

    ## visualisation
    vis_folder = os.path.join(out, "vis")
    os.makedirs(vis_folder, exist_ok=True)
    if not args.vis_freq:
        args.vis_freq = len(train_A_dataset) // 2
    s = [k for k in range(args.num_slices)
         ] if args.num_slices > 0 and args.imgtype == "dcm" else None
    trainer.extend(VisEvaluator({
        "testA": test_A_iter,
        "testB": test_B_iter
    }, {
        "enc_x": enc_x,
        "enc_y": enc_y,
        "dec_x": dec_x,
        "dec_y": dec_y
    },
                                params={
                                    'vis_out': vis_folder,
                                    'slice': s
                                },
                                device=args.gpu[0]),
                   trigger=(args.vis_freq, 'iteration'))

    ## output filenames of training dataset
    with open(os.path.join(out, 'trainA.txt'), 'w') as output:
        for f in train_A_dataset.names:
            output.write("\n".join(f))
            output.write("\n")
    with open(os.path.join(out, 'trainB.txt'), 'w') as output:
        for f in train_B_dataset.names:
            output.write("\n".join(f))
            output.write("\n")

    # archive the scripts
    rundir = os.path.dirname(os.path.realpath(__file__))
    import zipfile
    with zipfile.ZipFile(os.path.join(out, 'script.zip'),
                         'w',
                         compression=zipfile.ZIP_DEFLATED) as new_zip:
        for f in [
                'train.py', 'net.py', 'updater.py', 'consts.py', 'losses.py',
                'arguments.py', 'convert.py'
        ]:
            new_zip.write(os.path.join(rundir, f), arcname=f)

    # Run the training
    trainer.run()
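For reference (not part of the original snippet): the Dataset constructors above imply the following directory layout under args.root; the per-folder annotations are inferred from how each split is used, and the image format depends on args.imgtype.
# <args.root>/
#     trainA/   domain-A training images
#     trainB/   domain-B training images
#     testA/    domain-A images used for visualisation
#     testB/    domain-B images used for visualisation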
Example #13
0
                                           shuffle=True,
                                           **kwargs)
val_loader = torch.utils.data.DataLoader(test,
                                         batch_size=mb_size,
                                         shuffle=True,
                                         **kwargs)
test_loader = torch.utils.data.DataLoader(test,
                                          batch_size=mb_size,
                                          shuffle=False,
                                          **kwargs)

contrastiveloss = contrastive.ContrastiveLoss()
#KLloss = contrastive.KL()

#enc = net.VariationalEncoder(dim=z_dim)
enc = net.Encoder(dim=z_dim)
#dec = net.Decoder(output_dim=(28, 28))
dec = net.Decoder(dim=z_dim)

if use_cuda:
    enc.cuda()
    dec.cuda()


def reset_grad():
    enc.zero_grad()
    dec.zero_grad()


def plot(samples):
    fig = plt.figure(figsize=(4, 4))
Example #14
0
def MVAE(X, AlgPara, NNPara, gpu=-1):
    # check errors and set default values
    I, J, M = X.shape
    N = M
    if N > I:
        sys.stderr.write(
            'The input spectrogram may be wrong: its shape must be (freq x frame x ch).\n'
        )

    W = np.zeros((I, M, N), dtype=complex)  # np.complex is removed in recent NumPy
    for i in range(I):
        W[i, :, :] = np.eye(N)

    # Parameter for ILRMA
    if AlgPara['nb'] is None:
        AlgPara['nb'] = np.ceil(J / 10)
    L = AlgPara['nb']
    T = np.maximum(np.random.rand(I, L, N), epsi)
    V = np.maximum(np.random.rand(L, J, N), epsi)

    R = np.zeros((I, J, N))  # variance matrix
    Y = np.zeros((I, J, N), dtype=complex)
    for i in range(0, I):
        Y[i, :, :] = (W[i, :, :] @ X[i, :, :].T).T
    P = np.maximum(np.abs(Y)**2, epsi)  # power spectrogram

    # ILRMA
    Y, W, R, P = ilrma(X, W, R, P, T, V, AlgPara['it0'], AlgPara['norm'])

    ####  CVAE ####
    # load trained networks
    n_freq = I - 1
    encoder = net.Encoder(n_freq, NNPara['n_src'])
    decoder = net.Decoder(n_freq, NNPara['n_src'])

    checkpoint = torch.load(NNPara['model_path'])
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    if gpu >= 0:
        device = torch.device("cuda:{}".format(gpu))
        encoder.cuda(device)
        decoder.cuda(device)
    else:
        device = torch.device("cpu")

    Q = np.zeros((N, I, J))  # estimated variance matrix
    P = P.transpose(2, 0, 1)
    R = R.transpose(2, 0, 1)

    # initial z and l
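    # gv: per-source global power of the magnitude spectrogram (DC bin excluded), used for normalisation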
    Y_abs = abs(Y).astype(np.float32).transpose(2, 0, 1)
    gv = np.mean(np.power(Y_abs[:, 1:, :], 2), axis=(1, 2), keepdims=True)
    Y_abs_norm = Y_abs / np.sqrt(gv)
    eps = np.ones(Y_abs_norm.shape) * epsi
    Y_abs_array_norm = np.maximum(Y_abs_norm, eps)[:, None]

    zs, ls, models, optims = [], [], [], []
    for n in range(N):
        y_abs = torch.from_numpy(
            np.asarray(Y_abs_array_norm[n, None, :, 1:, :],
                       dtype="float32")).to(device)
        label = torch.from_numpy(
            np.ones((1, NNPara['n_src']), dtype="float32") /
            NNPara['n_src']).to(device)
        z = encoder(y_abs, label)[0]
        zs.append(z)
        ls.append(label)
        Q[n, 1:, :] = np.squeeze(np.exp(
            decoder(z, label).detach().to("cpu").numpy()),
                                 axis=1)

    Q = np.maximum(Q, epsi)
    gv = np.mean(np.divide(P[:, 1:, :], Q[:, 1:, :]),
                 axis=(1, 2),
                 keepdims=True)
    Rhat = np.multiply(Q, gv)
    Rhat[:, 0, :] = R[:, 0, :]
    R = Rhat

    # Model construction
    for para in decoder.parameters():
        para.requires_grad = False

    for n in range(N):
        z_para = torch.nn.Parameter(zs[n].type(torch.float),
                                    requires_grad=True)
        l_para = torch.nn.Parameter(ls[n].type(torch.float),
                                    requires_grad=True)
        src_model = net.SourceModel(decoder, z_para, l_para)
        if gpu >= 0:
            src_model.cuda(device)
        optimizer = torch.optim.Adam(src_model.parameters(), lr=0.01)
        models.append(src_model)
        optims.append(optimizer)

    # Algorithm for MVAE
    # Iterative update
    for it in range(AlgPara['it1']):
        Y_abs_array_norm = Y_abs / np.sqrt(gv)
        for n in range(N):
            y_abs = torch.from_numpy(
                np.asarray(Y_abs_array_norm[n, None, None, 1:, :],
                           dtype="float32")).to(device)
            for iz in range(100):
                optims[n].zero_grad()
                loss = models[n].loss(y_abs)
                loss.backward()
                optims[n].step()
            Q[n, 1:, :] = models[n].get_power_spec(cpu=True)
        Q = np.maximum(Q, epsi)
        gv = np.mean(np.divide(P[:, 1:, :], Q[:, 1:, :]),
                     axis=(1, 2),
                     keepdims=True)
        Rhat = np.multiply(Q, gv)
        Rhat[:, 0, :] = R[:, 0, :]
        R = Rhat.transpose(1, 2, 0)

        # update W
        W = update_w(X, R, W)
        Y = X @ W.conj()
        Y_abs = np.abs(Y)
        Y_pow = np.power(Y_abs, 2)
        P = np.maximum(Y_pow, epsi)

        if AlgPara['norm']:
            W, R, P = local_normalize(W, R, P, I, J)

        Y_abs = Y_abs.transpose(2, 0, 1)
        P = P.transpose(2, 0, 1)
        R = R.transpose(2, 0, 1)

    return Y
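A minimal usage sketch for the function above (not part of the original snippet): the AlgPara/NNPara keys mirror the ones read inside MVAE, while the STFT stand-in, iteration counts, and checkpoint path are assumptions.
import numpy as np

# random stand-in for a (freq x frame x ch) complex STFT; replace with a real mixture
X = np.random.randn(513, 128, 2) + 1j * np.random.randn(513, 128, 2)

AlgPara = {'nb': None, 'it0': 30, 'it1': 30, 'norm': True}   # iteration counts are placeholders
NNPara = {'n_src': 2, 'model_path': 'cvae_checkpoint.pt'}    # checkpoint with encoder/decoder state dicts
Y = MVAE(X, AlgPara, NNPara, gpu=-1)                         # separated estimates, same (freq x frame x ch) layout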
Example #15
0
    ]
    label_paths = ["{}{}/label.npy".format(data_root, f) for f in src_folders]
    n_src = len(src_folders)

    src_data = [sorted(os.listdir(p)) for p in data_paths]
    n_src_data = [len(d) for d in src_data]
    src_batch_size = [n // N_ITER for n in n_src_data]
    labels = [np.load(p) for p in label_paths]

    # =============== Set model ==============
    # Set up model and optimizer
    x_tmp = np.load(data_paths[0] + src_data[0][0])
    n_freq = x_tmp.shape[0] - 1
    del x_tmp

    encoder = net.Encoder(n_freq, n_src)
    decoder = net.Decoder(n_freq, n_src)
    cvae = net.CVAE(encoder, decoder)

    n_para_enc = sum(p.numel() for p in encoder.parameters())
    n_para_dec = sum(p.numel() for p in decoder.parameters())

    if config.gpu >= 0:
        device = torch.device("cuda:{}".format(config.gpu))
        cvae.cuda(device)
    else:
        device = torch.device("cpu")

    optimizer = torch.optim.Adam(cvae.parameters(), lr=config.lrate)

    # load pretrained model
Example #16
0
def do_train(rawdata, charcounts, maxlens, unique_onehotvals):
    train_f_labeled = 0.2
    n_batches = 2800
    mb_size = 128
    lr = 2.0e-4
    momentum = 0.5
    cnt = 0
    latent_dim = 32  # 24#
    recurrent_hidden_size = 24

    epoch_len = 8
    max_veclen = 0.0
    patience = 12 * epoch_len
    patience_duration = 0

    input_dict = {}
    input_dict['discrete'] = discrete_cols
    input_dict['continuous'] = continuous_cols

    input_dict['onehot'] = {}
    for k in onehot_cols:
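        # embedding width = ceil(log2(number of distinct values)) for each one-hot column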
        dim = int(np.ceil(np.log(len(unique_onehotvals[k])) / np.log(2.0)))
        input_dict['onehot'][k] = dim

    if len(charcounts) > 0:
        text_dim = int(np.ceil(np.log(len(charcounts)) / np.log(2.0)))
        input_dict['text'] = {t: text_dim for t in text_cols}
    else:
        text_dim = 0
        input_dict['text'] = {}
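    # resulting structure (illustrative):
    # input_dict = {'discrete': [...], 'continuous': [...],
    #               'onehot': {col: embedding_dim}, 'text': {col: text_dim}}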

    #data = Dataseq(rawdata, charcounts, input_dict, unique_onehotvals, maxlens)
    #data_idx = np.arange(data.__len__())
    data_idx = np.arange(rawdata.shape[0])
    np.random.shuffle(data_idx)
    n_folds = 6
    fold_size = 1.0 * rawdata.shape[0] / n_folds  #data.__len__() / n_folds
    folds = [
        data_idx[int(i * fold_size):int((i + 1) * fold_size)]
        for i in range(n_folds)
    ]

    fold_groups = {}
    fold_groups[0] = {'train': [0, 1, 2, 3], 'val': [4]}
    fold_groups[1] = {'train': [1, 2, 3, 4], 'val': [0]}
    fold_groups[2] = {'train': [0, 2, 3, 4], 'val': [1]}
    fold_groups[3] = {'train': [0, 1, 3, 4], 'val': [2]}
    fold_groups[4] = {'train': [0, 1, 2, 4], 'val': [3]}

    for fold in range(1):

        train_idx = np.array(
            list(
                itertools.chain.from_iterable(
                    [folds[i] for i in fold_groups[fold]['train']])))
        val_idx = np.array(
            list(
                itertools.chain.from_iterable(
                    [folds[i] for i in fold_groups[fold]['val']])))

        np.random.shuffle(train_idx)
        train_labeled_idx = train_idx[0:int(train_f_labeled * len(train_idx))]
        train_unlabed_idx = train_idx[int(train_f_labeled * len(train_idx)):]

        data = Dataseq(rawdata, charcounts, input_dict, unique_onehotvals,
                       maxlens)
        train = Subset(data, train_idx)
        val = Subset(data, val_idx)

        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        train_iter = torch.utils.data.DataLoader(train,
                                                 batch_size=int(mb_size / 1),
                                                 shuffle=True,
                                                 **kwargs)
        train_iter_unshuffled = torch.utils.data.DataLoader(train,
                                                            batch_size=mb_size,
                                                            shuffle=False,
                                                            **kwargs)
        val_iter = torch.utils.data.DataLoader(val,
                                               batch_size=mb_size,
                                               shuffle=False,
                                               **kwargs)

        embeddings = {}
        reverse_embeddings = {}
        onehot_embedding_weights = {}
        for k in onehot_cols:
            dim = input_dict['onehot'][k]
            onehot_embedding_weights[k] = net.get_embedding_weight(
                len(unique_onehotvals[k]), dim, use_cuda=use_cuda)
            embeddings[k] = nn.Embedding(len(unique_onehotvals[k]),
                                         dim,
                                         _weight=onehot_embedding_weights[k])
            reverse_embeddings[k] = net.EmbeddingToIndex(
                len(unique_onehotvals[k]),
                dim,
                _weight=onehot_embedding_weights[k])

        if text_dim > 0:
            text_embedding_weights = net.get_embedding_weight(
                len(charcounts) + 1, text_dim, use_cuda=use_cuda)
            text_embedding = nn.Embedding(len(charcounts) + 1,
                                          text_dim,
                                          _weight=text_embedding_weights)
            text_embeddingtoindex = net.EmbeddingToIndex(
                len(charcounts) + 1, text_dim, _weight=text_embedding_weights)
            for k in text_cols:
                embeddings[k] = text_embedding
                reverse_embeddings[k] = text_embeddingtoindex

        enc = net.Encoder(input_dict,
                          dim=latent_dim,
                          recurrent_hidden_size=recurrent_hidden_size)
        dec = net.Decoder(input_dict,
                          maxlens,
                          dim=latent_dim,
                          recurrent_hidden_size=recurrent_hidden_size)

        if use_cuda:
            embeddings = {k: embeddings[k].cuda() for k in embeddings.keys()}
            enc.cuda()
            dec.cuda()

        logloss = contrastive.GaussianOverlap()

        solver = optim.RMSprop(
            [p for em in embeddings.values() for p in em.parameters()] +
            [p for p in enc.parameters()] + [p for p in dec.parameters()],
            lr=lr,
            momentum=momentum)

        print('starting training')
        loss = 0.0
        loss0 = 0.0
        loss1 = 0.0
        loss2 = 0.0
        loss3 = 0.0

        logger_df = pd.DataFrame(columns=[
            'iter', 'train_loss', 'train_veclen', 'val_veclen', 'val_loss',
            'val_acc'
        ] + [t + '_correct'
             for t in to_predict] + [t + '_false' for t in to_predict])

        for it in range(n_batches):
            T = next(iter(train_iter))
            # for col, value in T.items():
            #    T[col] = torch.cat((value, value, value, value), 0)

            T, X, X2, mu, logvar, mu2, mu2d, mu_tm, logvar2, logvar2d, logvar_tm = calc_mus(
                T, embeddings, reverse_embeddings, enc, dec)
            enc_loss, enc_loss0, enc_loss1, enc_loss2, enc_loss3 = calc_losses(
                T, embeddings, mu, logvar, mu2, mu2d, mu_tm, logvar2, logvar2d,
                logvar_tm, logloss)

            enc_loss.backward()
            solver.step()

            enc.zero_grad()
            dec.zero_grad()
            for col in embeddings.keys():
                embeddings[col].zero_grad()

            loss += enc_loss.data.cpu().numpy()
            loss0 += enc_loss0.data.cpu().numpy()
            loss1 += enc_loss1.data.cpu().numpy()
            loss2 += enc_loss2.data.cpu().numpy()
            loss3 += enc_loss3.data.cpu().numpy()
            veclen = torch.mean(torch.pow(mu, 2))
            if it % epoch_len == 0:
                print(
                    it, loss / epoch_len, loss0 / epoch_len, loss1 / epoch_len,
                    loss2 / epoch_len, loss3 / epoch_len,
                    veclen.data.cpu().numpy())  # enc_loss.data.cpu().numpy(),

                n_targetvals = embeddings[to_predict[0]].weight.size(0)
                if use_cuda:
                    mu = torch.zeros(len(train), mu.size(1)).cuda()
                    logvar = torch.zeros(len(train), mu.size(1)).cuda()
                    mu2 = torch.zeros(len(train), mu.size(1)).cuda()
                    mu2d = torch.zeros(len(train), mu.size(1)).cuda()
                    mu_tm = torch.zeros((len(train), ) +
                                        mu_tm.size()[1:]).cuda()
                    logvar2 = torch.zeros(len(train), mu.size(1)).cuda()
                    logvar2d = torch.zeros(len(train), mu.size(1)).cuda()
                    logvar_tm = torch.zeros(len(train), 1 + n_targetvals,
                                            mu.size(1)).cuda()
                    train_loss = torch.zeros(len(train)).cuda()
                else:
                    mu = torch.zeros(len(train), mu.size(1))
                    logvar = torch.zeros(len(train), mu.size(1))
                    mu2 = torch.zeros(len(train), mu.size(1))
                    mu2d = torch.zeros(len(train), mu.size(1))
                    mu_tm = torch.zeros((len(train), ) + mu_tm.size()[1:])
                    logvar2 = torch.zeros(len(train), mu.size(1))
                    logvar2d = torch.zeros(len(train), mu.size(1))
                    logvar_tm = torch.zeros(len(train), 1 + n_targetvals,
                                            mu.size(1))
                    train_loss = torch.zeros(len(train))

                s = 0
                for T0 in train_iter_unshuffled:
                    e = s + T0[to_predict[0]].size(0)
                    if s == 0:
                        T = {
                            col: torch.zeros((len(train), ) + value.size()[1:],
                                             dtype=value.dtype)
                            for col, value in T0.items()
                        }

                    T0, Xsample, _, mu[s:e], logvar[s:e], mu2[s:e], mu2d[
                        s:e], mu_tm[s:e], logvar2[s:e], logvar2d[
                            s:e], logvar_tm[s:e] = calc_mus(T0,
                                                            embeddings,
                                                            reverse_embeddings,
                                                            enc,
                                                            dec,
                                                            mode='val')
                    for col, value in T0.items():
                        T[col][s:e] = T0[col]

                    n_targetvals = embeddings[to_predict[0]].weight.size(0)
                    mu_tm[s:e, 0, :] = 1.0 * mu[s:e]
                    if use_cuda:
                        p = torch.zeros((e - s), n_targetvals).cuda()
                    else:
                        p = torch.zeros((e - s), n_targetvals)

                    # encodings for all the possible target embedding values
                    for i in range(n_targetvals):
                        if use_cuda:
                            t = {
                                col: Xsample[col]
                                if not col in to_predict else embeddings[col](
                                    i * torch.ones_like(T0[col]).cuda())
                                for col in Xsample.keys()
                            }
                            mu_tm[s:e, i + 1, :], _ = enc(t)
                        else:
                            mu_tm[s:e, i + 1, :], _ = enc({
                                col: Xsample[col] if not col in to_predict else
                                embeddings[col](i * torch.ones_like(T0[col]))
                                for col in Xsample.keys()
                            })
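                        # map the latent distance between the original and the relabelled
                        # encoding to a pseudo-probability via the Gaussian error function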
                        diffsquares = torch.sqrt(
                            torch.mean(
                                torch.pow(
                                    mu_tm[s:e, 0, :] - mu_tm[s:e, i + 1, :],
                                    2), 1))
                        p[:, i] = 1.0 - torch.abs(torch.erf(diffsquares / 2.0))

                    labels = T0[to_predict[0]]
                    target = torch.zeros(e - s, n_targetvals)
                    target[torch.arange(e - s), labels] = 1
                    if use_cuda:
                        target = target.cuda()

                    #print(target[0:5])
                    #print(p[0:5])
                    p = p / torch.sum(p, 1).view(-1, 1).repeat(1, n_targetvals)

                    train_loss[s:e] += -torch.mean(
                        target * torch.log(torch.clamp(p, 1e-8, 1.0)) +
                        (1 - target) *
                        torch.log(torch.clamp(1 - p, 1e-8, 1.0)), 1)

                    s = e

                enc_loss, enc_loss0, enc_loss1, enc_loss2, enc_loss3 = calc_losses(
                    T,
                    embeddings,
                    mu,
                    logvar,
                    mu2,
                    mu2d,
                    mu_tm,
                    logvar2,
                    logvar2d,
                    logvar_tm,
                    logloss,
                    lookfordups=False)
                vl = torch.mean(torch.pow(mu, 2))

                print(f'train enc loss {enc_loss}')
                print(f'train veclen {vl}')
                print(f'mean train logvar {torch.mean(logvar)}')
                print(f'mean train_loss {torch.mean(train_loss)}')
                logger_df.loc[
                    int(it / epoch_len),
                    ['iter', 'train_loss', 'train_veclen', 'train_loss']] = [
                        it,
                        enc_loss.data.cpu().numpy(),
                        vl.data.cpu().numpy(),
                        torch.mean(train_loss).data.cpu().numpy()
                    ]

                if use_cuda:
                    mu = torch.zeros(len(val), mu.size(1)).cuda()
                    logvar = torch.zeros(len(val), mu.size(1)).cuda()
                    mu2 = torch.zeros(len(val), mu.size(1)).cuda()
                    mu2d = torch.zeros(len(val), mu.size(1)).cuda()
                    n_targetvals = embeddings[to_predict[0]].weight.size(0)
                    mu_tm = torch.zeros(len(val), 1 + n_targetvals,
                                        mu.size(1)).cuda()
                    val_loss = torch.zeros(len(val)).cuda()
                    val_accuracy = torch.zeros(len(val)).cuda()
                else:
                    mu = torch.zeros(len(val), mu.size(1))
                    logvar = torch.zeros(len(val), mu.size(1))
                    mu2 = torch.zeros(len(val), mu.size(1))
                    mu2d = torch.zeros(len(val), mu.size(1))
                    n_targetvals = embeddings[to_predict[0]].weight.size(0)
                    mu_tm = torch.zeros(len(val), 1 + n_targetvals, mu.size(1))
                    val_loss = torch.zeros(len(val))
                    val_accuracy = torch.zeros(len(val))

                s = 0
                targets = {}
                for T0 in val_iter:
                    e = s + T0[to_predict[0]].size(0)
                    print(s, e)

                    if s == 0:
                        correct = {
                            col: np.zeros((len(val), ) + v.size()[1:])
                            for col, v in T0.items()
                        }
                        actual = {
                            col: np.zeros((len(val), ) + v.size()[1:])
                            for col, v in T0.items()
                        }

                    Xsample = {}
                    for col, tt in T0.items():
                        if use_cuda:
                            tt = Variable(tt).cuda()
                        else:
                            tt = Variable(tt)

                        if col in embeddings.keys():
                            Xsample[col] = embeddings[col](tt)
                        else:
                            Xsample[col] = tt.float()

                        if col in to_predict:
                            targets[col] = tt
                            Xsample[col] = 0.0 * Xsample[col]

                    mu[s:e], logvar[s:e] = enc(Xsample)

                    X2sample = dec(mu[s:e])
                    T2sample = discretize(X2sample, embeddings, maxlens)

                    mu2[s:e], _ = enc(X2sample)

                    X2dsample = {
                        col: (1.0 * tt).detach()
                        for col, tt in X2sample.items()
                    }
                    for col in continuous_cols:
                        if col in to_predict:
                            correct[col][s:e] = np.abs(
                                X2sample[col].data.cpu().numpy().reshape(-1) -
                                targets[col].data.cpu().numpy().reshape(-1))
                            actual[col][s:e] = targets[col].data.cpu().numpy(
                            ).reshape(-1)
                        else:
                            correct[col][s:e] = np.abs(
                                X2sample[col].data.cpu().numpy().reshape(-1) -
                                T0[col].data.cpu().numpy().reshape(-1))
                            actual[col][s:e] = T0[col].data.cpu().numpy(
                            ).reshape(-1)

                    for col, embedding in embeddings.items():
                        # T2[col] = reverse_embeddings[col](X2sample[col])
                        X2dsample[col] = embeddings[col](
                            T2sample[col].detach())

                        if col in to_predict:
                            correct[col][s:e] = np.abs(T2sample[col].data.cpu(
                            ).numpy() == targets[col].data.cpu().numpy())
                            actual[col][s:e] = targets[col].data.cpu().numpy(
                            ).reshape(-1)
                        else:
                            correct[col][s:e] = np.abs(T2sample[col].data.cpu(
                            ).numpy() == T0[col].data.cpu().numpy())
                            actual[col][s:e] = T0[col].data.cpu().numpy(
                            ).reshape(-1)

                    mu2d[s:e], _ = enc(X2dsample)
                    # calculate target predictions for validation data

                    n_targetvals = embeddings[to_predict[0]].weight.size(0)
                    mu_tm[s:e, 0, :] = 1.0 * mu[s:e]
                    if use_cuda:
                        p = torch.zeros((e - s), n_targetvals).cuda()
                    else:
                        p = torch.zeros((e - s), n_targetvals)

                    # generate encodings for all the possible target embedding values
                    for i in range(n_targetvals):
                        if use_cuda:
                            t = {
                                col: Xsample[col]
                                if not col in to_predict else embeddings[col](
                                    i * torch.ones_like(T0[col]).cuda())
                                for col in Xsample.keys()
                            }
                            mu_tm[s:e, i + 1, :], _ = enc(t)
                        else:
                            mu_tm[s:e, i + 1, :], _ = enc({
                                col: Xsample[col] if not col in to_predict else
                                embeddings[col](i * torch.ones_like(T0[col]))
                                for col in Xsample.keys()
                            })
                        diffsquares = torch.sqrt(
                            torch.mean(
                                torch.pow(
                                    mu_tm[s:e, 0, :] - mu_tm[s:e, i + 1, :],
                                    2), 1))
                        p[:, i] = 1.0 - torch.abs(torch.erf(diffsquares / 2.0))

                        #print(mu_tm[s:s+5, i + 1, 0:5])
                        print(diffsquares[0:5])

                    labels = T0[to_predict[0]]
                    target = torch.zeros(e - s, n_targetvals)
                    target[torch.arange(e - s), labels] = 1
                    if use_cuda:
                        target = target.cuda()
                        labels = labels.cuda()

                    p = p / torch.sum(p, 1).view(-1, 1).repeat(1, n_targetvals)
                    val_accuracy[s:e] = torch.eq(labels,
                                                 torch.max(p, 1)[1]).float()

                    val_loss[s:e] += -torch.mean(
                        target * torch.log(torch.clamp(p, 1e-8, 1.0)) +
                        (1 - target) *
                        torch.log(torch.clamp(1 - p, 1e-8, 1.0)), 1)

                    s = e

                vl = torch.mean(torch.pow(mu, 2))

                print(f'val veclen {vl}')
                print(f'mean es logvar {torch.mean(logvar)}')
                print(f'mean val_loss {torch.mean(val_loss)}')
                print(f'mean val_accuracy {torch.mean(val_accuracy)}')

                logger_df.loc[
                    int(it / epoch_len),
                    ['val_veclen', 'val_loss', 'val_acc']] = vl.data.cpu(
                    ).numpy(), torch.mean(val_loss).data.cpu().numpy(
                    ), torch.mean(val_accuracy).data.cpu().numpy()
                for target_col in to_predict:
                    logger_df.loc[
                        int(it / epoch_len),
                        [target_col + '_correct', target_col +
                         '_false']] = np.mean(correct[target_col]), np.mean(
                             actual[target_col] == 0)

                for col in continuous_cols:
                    # print(np.abs(T0[col].data.cpu().numpy().reshape(-1) - T2sample[col].data.cpu().numpy().reshape(-1)))
                    print(f'% {col} mae: {np.mean(correct[col])}')

                for col in onehot_cols:
                    print(
                        f'% {col} correct: {np.mean(correct[col])} {np.mean(actual[col]==0)}'
                    )

                loss = 0.0
                loss0 = 0.0
                loss1 = 0.0
                loss2 = 0.0
                loss3 = 0.0
                # print(T2.data.cpu()[0, 0:30].numpy())

        logger_df.to_csv('logger_' + str(fold) + '.csv', index=False)