Example #1
0
        self.mse_loss = nn.MSELoss(reduction="sum")

    # self.crispyLoss = MS_SSIM()

    def forward(self, x_recon, x, mu, logvar, epoch, kld_annealing=None):
        """Return the VAE loss for one batch with a warm-up on the KL term.

        The loss is the reconstruction error from ``self.mse_loss`` plus the
        closed-form KL divergence between the diagonal Gaussian
        N(mu, exp(logvar)) and N(0, I). The KL term is scaled by
        ``min(1, round(epoch / 2 + 0.75) * kld_annealing)``, so its weight
        grows with the epoch until it saturates at 1.

        Args:
            x_recon: reconstruction produced by the decoder.
            x: target, compared element-wise with ``x_recon``.
            mu: latent means from the encoder.
            logvar: latent log-variances matching ``mu``.
            epoch: current training epoch; drives the KL warm-up.
            kld_annealing: warm-up rate per step of the schedule. When
                omitted, the module-level ``KLD_annealing`` is used.

        Returns:
            Scalar loss tensor.
        """
        loss_MSE = self.mse_loss(x_recon, x)
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # MS-SSIM ("crispy") term is disabled in this variant:
        # loss_crispy = self.crispyLoss(x_recon, x)

        if kld_annealing is None:
            kld_annealing = KLD_annealing
        # The schedule must follow the ``epoch`` passed in by the caller, not
        # a global ``epochs`` value, or the warm-up never advances.
        kld_weight = min(1.0, float(round(epoch / 2 + 0.75)) * kld_annealing)
        return loss_MSE + kld_weight * loss_KLD


encoder = PictureEncoder()
decoder = PictureDecoder()
model = GeneralVae(encoder, decoder).cuda()

if data_para and torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # nn.DataParallel requires the wrapped module's parameters to live on
    # device_ids[0], but .cuda() above placed them on the current device.
    # Keep only the requested GPUs (4-7) that exist on this machine and move
    # the model onto the first of them; otherwise use every visible GPU.
    dp_device_ids = [i for i in (4, 5, 6, 7) if i < torch.cuda.device_count()]
    if dp_device_ids:
        model = model.cuda(dp_device_ids[0])
        model = nn.DataParallel(model, dp_device_ids)
    else:
        model = nn.DataParallel(model)

optimizer = optim.Adam(model.parameters(), lr=LR)
#sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5, eta_min=5e-4, last_epoch=-1)

val_losses = []
train_losses = []
lossf = customLoss()


def get_batch_size(epoch):
Example #2
0
class customLoss(nn.Module):
    """VAE loss: summed MSE + KL term with warm-up + optional MS-SSIM term.

    Args:
        crispy_weight: weight of the MS-SSIM ("crispy") term. The default of
            0.0 keeps the term disabled. When the weight is zero, the term is
            not evaluated at all. This saves its forward/backward cost and
            keeps a NaN from that term (``0 * nan == nan``) out of the total.
    """

    def __init__(self, crispy_weight=0.0):
        super(customLoss, self).__init__()
        self.mse_loss = nn.MSELoss(reduction="sum")
        self.crispyLoss = MS_SSIM()
        self.crispy_weight = crispy_weight

    def forward(self, x_recon, x, mu, logvar, epoch, kld_annealing=None):
        """Return the combined loss for one batch.

        Args:
            x_recon: reconstruction produced by the decoder.
            x: target, compared element-wise with ``x_recon``.
            mu: latent means from the encoder.
            logvar: latent log-variances matching ``mu``.
            epoch: current training epoch; drives the KL warm-up
                ``min(1, round(epoch / 2 + 0.75) * kld_annealing)``.
            kld_annealing: warm-up rate. When omitted, the module-level
                ``KLD_annealing`` is used.

        Returns:
            Scalar loss tensor.
        """
        loss_MSE = self.mse_loss(x_recon, x)
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)).
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        if kld_annealing is None:
            kld_annealing = KLD_annealing
        # Use the caller's ``epoch``, not a global ``epochs``, so the warm-up
        # actually advances during training.
        kld_weight = min(1.0, float(round(epoch / 2 + 0.75)) * kld_annealing)

        loss = loss_MSE + kld_weight * loss_KLD
        if self.crispy_weight:
            loss = loss + self.crispy_weight * self.crispyLoss(x_recon, x)
        return loss

# Build a fresh SMILES -> image model, or unpickle a complete saved model.
model = (SmilesToImageModle(SmilesEncoder(embedding_width, embedding_size),
                            PictureDecoder())
         if model_load is None
         else torch.load(model_load))

# Optionally overwrite the weights with a separately saved state dict.
if load_state is not None:
    model.load_state_dict(torch.load(load_state))

# Replicate across all visible GPUs when data parallelism is requested.
if data_para and torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

model = model.to(device)

optimizer = optim.Adam(params=model.parameters(), lr=LR)
# Alternatives previously tried:
# optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.8, nesterov=True)
# sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 10, eta_min=0.000001, last_epoch=-1)
Example #3
0
    def forward(self, x_recon, x, mu, logvar, epoch, kld_annealing=None):
        """Return reconstruction + warmed-up KL + MS-SSIM loss for one batch.

        Args:
            x_recon: reconstruction produced by the decoder.
            x: target, compared element-wise with ``x_recon``.
            mu: latent means from the encoder.
            logvar: latent log-variances matching ``mu``.
            epoch: current training epoch; drives the KL warm-up
                ``min(1, round(epoch / 2 + 0.75) * kld_annealing)``.
            kld_annealing: warm-up rate. When omitted, the module-level
                ``KLD_annealing`` is used.

        Returns:
            Scalar loss tensor.
        """
        loss_MSE = self.mse_loss(x_recon, x)
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)).
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss_cripsy = self.crispyLoss(x_recon, x)

        if kld_annealing is None:
            kld_annealing = KLD_annealing
        # Use the caller's ``epoch``, not a global ``epochs``, so the warm-up
        # actually advances during training.
        kld_weight = min(1.0, float(round(epoch / 2 + 0.75)) * kld_annealing)
        return loss_MSE + kld_weight * loss_KLD + loss_cripsy


encoder = PictureEncoder()
decoder = PictureDecoder()

# Restore encoder/decoder weights from epoch 100 of the mixed_im_im_small run.
# Map the checkpoint to the CPU: load_state_dict copies the tensors into the
# modules' own parameters, and the model is moved to the GPU below anyway.
# Mapping to cuda:0 would only keep a redundant GPU copy alive through the
# module-level `checkpoint` name.
checkpoint = torch.load(
    '/homes/aclyde11/imageVAE/mixed_im_im_small/model/mixed_epoch_100.pt',
    map_location="cpu")
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

model = GeneralVae(encoder, decoder, rep_size=500).cuda()

binding_model = BindingAffModel(rep_size=500).cuda()

# if data_para and torch.cuda.device_count() > 1:
#     print("Let's use", torch.cuda.device_count(), "GPUs!")
#     model = nn.DataParallel(model)
Example #4
0
class customLoss(nn.Module):
    """Weighted VAE loss: MSE + KL term with warm-up + MS-SSIM term.

    Args:
        mse_weight: weight of the summed-MSE reconstruction term.
        crispy_weight: weight of the MS-SSIM ("crispy") term. When it is
            zero, the term is not evaluated at all.
    """

    def __init__(self, mse_weight=1.25, crispy_weight=0.9):
        super(customLoss, self).__init__()
        self.mse_loss = nn.MSELoss(reduction="sum")
        self.crispyLoss = MS_SSIM()
        self.mse_weight = mse_weight
        self.crispy_weight = crispy_weight

    def forward(self, x_recon, x, mu, logvar, epoch, kld_annealing=None):
        """Return the weighted combined loss for one batch.

        Args:
            x_recon: reconstruction produced by the decoder.
            x: target, compared element-wise with ``x_recon``.
            mu: latent means from the encoder.
            logvar: latent log-variances matching ``mu``.
            epoch: current training epoch; drives the KL warm-up
                ``min(1, round(epoch / 2 + 0.75) * kld_annealing)``.
            kld_annealing: warm-up rate. When omitted, the module-level
                ``KLD_annealing`` is used.

        Returns:
            Scalar loss tensor.
        """
        loss_MSE = self.mse_loss(x_recon, x)
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)).
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        if kld_annealing is None:
            kld_annealing = KLD_annealing
        # Use the caller's ``epoch``, not a global ``epochs``, so the warm-up
        # actually advances during training.
        kld_weight = min(1.0, float(round(epoch / 2 + 0.75)) * kld_annealing)

        loss = self.mse_weight * loss_MSE + kld_weight * loss_KLD
        if self.crispy_weight:
            loss = loss + self.crispy_weight * self.crispyLoss(x_recon, x)
        return loss

# Build a fresh SMILES -> image model, or unpickle a complete saved model.
model = (SmilesToImageModle(SmilesEncoder(50, 50), PictureDecoder())
         if model_load is None
         else torch.load(model_load))

# Optionally overwrite the weights with a separately saved state dict.
if load_state is not None:
    model.load_state_dict(torch.load(load_state))

# Replicate across all visible GPUs when data parallelism is requested.
if data_para and torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

model = model.to(device)

# Nesterov SGD with a fixed learning rate (the configured LR is not used here).
# Alternative: optimizer = optim.Adam(model.parameters(), lr=LR)
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-4,
                            momentum=0.8, nesterov=True)
# sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 10, eta_min=0.000001, last_epoch=-1)

# The full customLoss criterion (MSE + KL + MS-SSIM), despite the name.
loss_mse = customLoss()
Example #5
0
        sampler=torch.utils.data.SubsetRandomSampler(list(range(0,
                                                                data_size))),
        **kwargs)


# Saved epoch-15 weights for the "1" (Picture*) and "2" (Mol*) sub-networks.
# In this section they are only referenced by the commented-out resume code.
model_load1 = dict(
    decoder='/homes/aclyde11/imageVAE/combo/model/decoder1_epoch_15.pt',
    encoder='/homes/aclyde11/imageVAE/combo/model/encoder1_epoch_15.pt',
)
model_load2 = dict(
    decoder='/homes/aclyde11/imageVAE/combo/model/decoder2_epoch_15.pt',
    encoder='/homes/aclyde11/imageVAE/combo/model/encoder2_epoch_15.pt',
)

# Fresh sub-networks. Keep this construction order: the constructors presumably
# draw initial weights from the global RNG, so reordering could change them.
encoder1 = PictureEncoder()
decoder1 = PictureDecoder()
decoder2 = MolDecoder()
encoder2 = DenseMolEncoder()
# To resume instead:
# encoder1 = torch.load(model_load1['encoder'])
# encoder2 = torch.load(model_load2['encoder'])
# decoder1 = torch.load(model_load1['decoder'])
# decoder2 = torch.load(model_load2['decoder'])

model = ComboVAE(encoder1, encoder2, decoder1, decoder2, rep_size=500)
model = model.cuda()

# Replicate across all visible GPUs when data parallelism is requested.
if data_para and torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

optimizer = optim.Adam(params=model.parameters(), lr=LR)
# Alternative: optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.8, nesterov=True)
#optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.8, nesterov=True)
Example #6
0
    kwargs = {
        'num_workers': args.workers,
        'pin_memory': True
    } if cuda else {
        'num_workers': args.workers
    }

    print("\nloading data...")
    smiles_lookup_train = pd.read_csv(f"{args.d}/train.csv")
    print(smiles_lookup_train.head())
    smiles_lookup_test = pd.read_csv(f"{args.d}/test.csv")
    print(smiles_lookup_test.head())
    print("Done.\n")

    encoder = PictureEncoder(rep_size=512)
    decoder = PictureDecoder(rep_size=512)

    checkpoint = None
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        print(
            f"Loading Checkpoint ({args.checkpoint}). Starting at epoch: {checkpoint['epoch'] + 1}."
        )
        starting_epoch = checkpoint['epoch'] + 1
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

    encoder = encoder.to(device)
    decoder = decoder.to(device)
    model = GeneralVae(encoder, decoder, rep_size=512).to(device)
Example #7
0
def main():
    """Train the image VAE with NVIDIA apex AMP (optional SyncBN / DDP).

    Configuration comes from the module-level ``args`` namespace. The best
    (lowest) validation score seen so far is kept in the module-level
    ``best_prec1``. On local rank 0, a checkpoint is written after every
    epoch and flagged as best when the score improved.

    Raises:
        NotImplementedError: if ``args.pretrained`` is set; no pretrained
            model is wired up for this architecture.
    """
    global best_prec1, args

    # A WORLD_SIZE above 1 in the environment means distributed training.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        # Bind this process to the GPU of its local rank.
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.pretrained:
        print("=> using pre-trained model")
        # Nothing builds a model on this path; fail clearly here instead of
        # continuing with model = None and crashing on model.cuda() below.
        raise NotImplementedError(
            "args.pretrained is not supported: no pretrained model is "
            "available for this architecture")

    print("=> creating model")
    # Warm-start from the epoch-67 checkpoint. The decoder is loaded with
    # strict=False, so missing/unexpected decoder keys are tolerated.
    checkpoint = torch.load(
        '/home/aclyde11/imageVAE/im_im_small/model/epoch_67.pt',
        map_location=torch.device('cpu'))
    encoder = PictureEncoder()
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder = PictureDecoder()
    decoder.load_state_dict(checkpoint['decoder_state_dict'], strict=False)
    model = GeneralVae(encoder, decoder)

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale the configured LR linearly with the global batch size (base 256).
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    print(args.lr)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    if args.distributed:
        # delay_allreduce delays all communication to the end of the backward
        # pass instead of overlapping it with computation.
        model = DDP(model, delay_allreduce=True)

    criterion = customLoss()

    train_dataset = MoleLoader(smiles_lookup_train)
    val_dataset = MoleLoader(smiles_lookup_test)

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    # Shuffle only when no sampler is set (DataLoader forbids both at once).
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    # Lower validation score is better (it is compared with `<`; presumably
    # the validation loss). Start from +inf so the first epoch counts as an
    # improvement.
    best_prec1 = float('inf')
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # Track the minimum: pairing `<` with max() kept best_prec1 pinned at
        # its initial sentinel, so every epoch was flagged as best.
        if args.local_rank == 0:
            is_best = prec1 < best_prec1
            if is_best:
                best_prec1 = prec1
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)