Beispiel #1
0
    # self.crispyLoss = MS_SSIM()

    def forward(self, x_recon, x, mu, logvar, epoch):
        """Return the VAE loss: reconstruction error plus an annealed KL term.

        Args:
            x_recon: Reconstruction to be scored.
            x: Target that ``x_recon`` is compared against.
            mu: Mean term of the KL expression.
            logvar: Log-variance term of the KL expression.
            epoch: Current epoch number; it drives the KL weight.

        Returns:
            ``self.mse_loss(x_recon, x) + w * KLD`` with
            ``w = min(1.0, round(epoch / 2 + 0.75) * KLD_annealing)``.
        """
        loss_MSE = self.mse_loss(x_recon, x)
        # Closed-form KL divergence between N(mu, exp(logvar)) and a standard
        # normal, summed over every element of the inputs.
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # The MS-SSIM ("crispy") term is disabled in this variant of the loss.

        # Use the epoch passed in by the caller. The previous code read the
        # unrelated name `epochs`, so the weight did not follow the `epoch`
        # argument at all. The weight grows in steps and is capped at 1.0.
        kld_weight = min(1.0, float(round(epoch / 2 + 0.75)) * KLD_annealing)
        return loss_MSE + kld_weight * loss_KLD


# Build the VAE from a fresh encoder/decoder pair and move it to the current
# CUDA device.
encoder = PictureEncoder()
decoder = PictureDecoder()
model = GeneralVae(encoder, decoder).cuda()

# Optionally wrap the model for multi-GPU data parallelism.
# NOTE(review): device ids 4-7 are hard-coded while the guard only checks for
# more than one visible GPU, and the model was moved with a bare .cuda().
# Confirm the default CUDA device matches device_ids[0] on the target machine.
if data_para and torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model, [4, 5, 6, 7])

# Adam over every parameter of the (possibly wrapped) model. The cosine
# learning-rate scheduler below is currently disabled.
optimizer = optim.Adam(model.parameters(), lr=LR)
#sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5, eta_min=5e-4, last_epoch=-1)

# Loss histories start empty; presumably appended to by the training and
# validation loops elsewhere in the file.
val_losses = []
train_losses = []
# Loss object built from customLoss.
lossf = customLoss()


def get_batch_size(epoch):
    """Return the batch size to use for ``epoch``.

    The size is currently constant, so ``epoch`` is accepted but ignored.
    An epoch-dependent schedule, ``min(64 + 16 * epoch, 322)``, is kept here
    only as a reference to the previously tried setting.
    """
    fixed_batch_size = 700
    return fixed_batch_size
Beispiel #2
0
                              KLD_annealing) * loss_KLD + loss_cripsy


# Placeholders; encoder and decoder are rebound immediately below and model
# is rebound after the checkpoint has been loaded.
model = None
encoder = None
decoder = None
encoder = PictureEncoder()
decoder = PictureDecoder()

# Load the checkpoint file 'mixed_epoch_100.pt' from the hard-coded model
# directory, remapping its stored tensors onto cuda:0.
checkpoint = torch.load('/homes/aclyde11/imageVAE/mixed_im_im_small/model/' +
                        'mixed_epoch_' + str(100) + '.pt',
                        map_location="cuda:0")
# Restore encoder and decoder weights from their checkpoint entries.
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

# Assemble the VAE around the restored modules and move it to the GPU.
model = GeneralVae(encoder, decoder, rep_size=500).cuda()

# Second model, built with the same rep_size as the VAE; presumably a
# binding-affinity predictor on the latent representation - confirm against
# the BindingAffModel definition.
binding_model = BindingAffModel(rep_size=500).cuda()

# Multi-GPU wrapping is disabled in this script.
# if data_para and torch.cuda.device_count() > 1:
#     print("Let's use", torch.cuda.device_count(), "GPUs!")
#     model = nn.DataParallel(model)

#model.to(device)

# No optimizer is created for the VAE in this section (the lines are
# commented out); only binding_model gets one.
#optimizer = optim.Adam(model.parameters(), lr=LR)
#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

binding_optimizer = optim.Adam(binding_model.parameters(), lr=5e-4)
#optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, nesterov=True)
#sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 10, eta_min=1e-5, last_epoch=-1)
Beispiel #3
0
        loss_MSE = self.mse_loss(x_recon, x)
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss_cripsy = self.crispyLoss(x_recon, x)

        return loss_MSE + min(1.0, float(round(epochs / 2 + 0.75)) * KLD_annealing) * loss_KLD +  loss_cripsy

# Placeholders; all three names are rebound below.
model = None
encoder = None
decoder = None
# Either start from freshly constructed modules or load previously saved
# encoder/decoder objects from the paths given in model_load.
if model_load is None:
    encoder = PictureEncoder()
    decoder = PictureDecoder()
else:
    # NOTE(review): torch.load is used here to load whole objects rather than
    # state dicts, which relies on pickle - only point model_load at trusted
    # files.
    encoder = torch.load(model_load['encoder'])
    decoder = torch.load(model_load['decoder'])
model = GeneralVae(encoder, decoder, rep_size=500)


# Wrap for multi-GPU data parallelism when enabled and more than one GPU is
# visible; with no device_ids given, DataParallel uses all visible devices.
if data_para and torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

model.to(device)

# Adam over all model parameters; the SGD and cosine-scheduler alternatives
# below are disabled.
optimizer = optim.Adam(model.parameters(), lr=LR)
#optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.8, nesterov=True)
#sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 10, eta_min=0.000001, last_epoch=-1)
# Loss object built from customLoss (despite the name, not necessarily plain
# MSE - see the customLoss definition).
loss_mse = customLoss()


def get_batch_size(epoch):
Beispiel #4
0
def main():
    """Train and validate the VAE, optionally across several processes.

    Configuration is read from the module-level ``args`` namespace, and the
    module-level ``best_prec1`` is updated as training progresses. The
    function:

    1. Detects distributed mode from the ``WORLD_SIZE`` environment variable
       and, if active, binds this process to ``args.local_rank`` and
       initialises the NCCL process group.
    2. Builds a ``GeneralVae`` from the encoder/decoder weights stored in a
       hard-coded checkpoint file.
    3. Scales ``args.lr`` by ``batch_size * world_size / 256``, creates an
       Adam optimizer and passes model and optimizer through
       ``amp.initialize``.
    4. Builds train/validation loaders and runs the epoch loop, saving a
       checkpoint from the process with ``args.local_rank == 0`` after each
       validation pass.

    Raises:
        NotImplementedError: if ``args.pretrained`` is set; that branch does
            not construct a model.
    """
    global best_prec1, args

    # Distributed mode is enabled when WORLD_SIZE is set to more than 1.
    args.distributed = int(os.environ.get('WORLD_SIZE', '1')) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.pretrained:
        print("=> using pre-trained model")
        # This branch never built a model, so the code below used to fail
        # with an AttributeError on None. Fail here with a clear message.
        raise NotImplementedError(
            "args.pretrained is set, but no pre-trained model construction "
            "is implemented")

    print("=> creating model")
    # Weights are loaded onto the CPU first; the assembled model is moved to
    # the GPU further down.
    checkpoint = torch.load(
        '/home/aclyde11/imageVAE/im_im_small/model/epoch_67.pt',
        map_location=torch.device('cpu'))
    encoder = PictureEncoder()
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder = PictureDecoder()
    # strict=False tolerates key mismatches between the checkpoint and the
    # current decoder definition.
    decoder.load_state_dict(checkpoint['decoder_state_dict'], strict=False)
    model = GeneralVae(encoder, decoder)

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale the learning rate with the global batch size (reference: 256).
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    print(args.lr)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps
        # communication with computation in the backward pass;
        # delay_allreduce delays all communication to the end of it.
        model = DDP(model, delay_allreduce=True)

    criterion = customLoss()

    train_dataset = MoleLoader(smiles_lookup_train)
    val_dataset = MoleLoader(smiles_lookup_test)

    # In distributed mode the samplers are created here; the train loader
    # must then not shuffle on its own (shuffle is disabled when a sampler
    # is supplied).
    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    # The validation score is treated as lower-is-better, so start from a
    # large sentinel value.
    best_prec1 = 100000
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Give the sampler the epoch number so ordering varies per epoch.
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            # Profiling runs stop after a single training epoch.
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember the best (lowest) validation score and save a checkpoint
        if args.local_rank == 0:
            is_best = prec1 < best_prec1
            # min, not max: with max the sentinel was never replaced, so
            # every epoch was reported as the best one.
            best_prec1 = min(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Beispiel #5
0
    encoder = PictureEncoder(rep_size=512)
    decoder = PictureDecoder(rep_size=512)

    checkpoint = None
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        print(
            f"Loading Checkpoint ({args.checkpoint}). Starting at epoch: {checkpoint['epoch'] + 1}."
        )
        starting_epoch = checkpoint['epoch'] + 1
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

    encoder = encoder.to(device)
    decoder = decoder.to(device)
    model = GeneralVae(encoder, decoder, rep_size=512).to(device)

    def interpolate_points(x, y, sampling):
        """Evaluate the straight line through ``x`` and ``y`` at ``sampling``.

        A ``LinearRegression`` is fitted with ``x`` as the target at input 0
        and ``y`` as the target at input 1, then evaluated at every value in
        ``sampling``. Since a line through two points is exact, this is
        effectively linear interpolation (or extrapolation outside [0, 1])
        between the two.

        Args:
            x: Array used as the regression target at input 0.
            y: Array used as the regression target at input 1; it is stacked
                with ``x``, so it must have the same shape.
            sampling: NumPy array of positions to evaluate; it is reshaped to
                a single column before prediction.

        Returns:
            The predictions cast to ``np.float32``, one row per sampled
            position.
        """
        endpoints = np.stack((x, y))
        # Regression inputs 0 and 1 label the two endpoints.
        anchors = np.array([0, 1]).reshape(-1, 1)

        regressor = LinearRegression()
        regressor.fit(anchors, endpoints)

        positions = sampling.reshape(-1, 1)
        predicted = regressor.predict(positions)
        return predicted.astype(np.float32)

    times = int(args.n / args.b)
    print(
        f"Using batch size {args.b} and sampling {times} times for a total of {args.b * times} samples drawn. Saving {'images' if args.image else 'numpy array'}"
    )
    samples = []
    for i in range(times):
    X_tsne = tsne.fit_transform(X_encoded)

    # Plot images according to t-sne embedding
    print("Plotting t-SNE visualization...")
    fig, ax = plt.subplots()
    imscatter(X_tsne[:, 0], X_tsne[:, 1], imageData=data, ax=ax, zoom=0.1)

    plt.savefig('books_read.png', dpi=900)


# Build a loader over the validation root. The meaning of the 200 and 20000
# arguments is defined by generate_data_loader - presumably batch size and
# sample count; confirm against its signature.
loader = generate_data_loader(val_root, 200, int(20000))
#train(epoch)
# Keep only the first batch of inputs from the loader (the second element of
# each item is discarded) and move it to the GPU.
data = None
for d, _ in loader:
    data = d
    break
data = data.cuda()
# range(50, 51) yields only i = 50, so a single saved epoch is processed.
for i in range(50, 51):

    # NOTE(review): torch.load is used here to load whole objects rather than
    # state dicts, which relies on pickle - only load trusted files.
    encoder = torch.load('/homes/aclyde11/imageVAE/im_im/model/encoder_epoch_' + str(i)+ '.pt')
    decoder = torch.load('/homes/aclyde11/imageVAE/im_im/model/decoder_epoch_' + str(i)+ '.pt')
    model = GeneralVae(encoder, decoder).cuda()
    #sample(i, model, data)
    sample_plot(i, model, data)
    # Drop the reference to the model before the next iteration.
    del model