Example #1
import torch
import torch.nn as nn
# Optimize only the encoder parameters that require gradients; skip the
# optimizer entirely when the encoder is frozen.
encoder_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, encoder.parameters()),
    lr=encoder_lr) if fine_tune_encoder else None

decoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(decoder_optimizer,
                                                           T_max=8,
                                                           eta_min=5e-6,
                                                           last_epoch=-1)
# Build the encoder scheduler only when the encoder is being fine-tuned;
# CosineAnnealingLR cannot be constructed on a None optimizer.
encoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    encoder_optimizer, T_max=8, eta_min=5e-6,
    last_epoch=-1) if fine_tune_encoder else None

# Both loaders use a SubsetRandomSampler, so shuffle must stay False;
# note that they currently draw from the same index range.
train_loader = torch.utils.data.DataLoader(
    MoleLoader(smiles_lookup, num=1000000),
    batch_size=325,
    shuffle=False,
    drop_last=True,
    sampler=torch.utils.data.SubsetRandomSampler(list(range(0, data_size))),
    **kwargs)
val_loader = torch.utils.data.DataLoader(
    MoleLoader(smiles_lookup, num=1000000),
    batch_size=325,
    shuffle=False,
    drop_last=True,
    sampler=torch.utils.data.SubsetRandomSampler(list(range(0, data_size))),
    **kwargs)
criterion = nn.CrossEntropyLoss().cuda(2)  # place the loss module on GPU 2
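
For context, a minimal sketch of how these pieces are typically stepped once per epoch (num_epochs and train_epoch are hypothetical names, not from the original):

for epoch in range(num_epochs):
    train_epoch(train_loader, criterion, decoder_optimizer, encoder_optimizer)
    decoder_sched.step()
    if encoder_sched is not None:  # encoder pieces exist only when fine-tuning
        encoder_sched.step()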

Example #2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
def apply_one_hot(ch):
    # Pad each one-hot encoding to a fixed length of 60 positions.
    # Use a list comprehension here: in Python 3, np.array(map(...)) would
    # wrap the map iterator itself instead of the encoded rows.
    return np.array([
        np.pad(one_hot_encoded_fn(x),
               pad_width=[(0, 60 - len(x)), (0, 0)],
               mode='constant',
               constant_values=0) for x in ch
    ])


smiles_lookup_train = pd.read_csv("/homes/aclyde11/moses/data/train.csv")
print(smiles_lookup_train.head())
smiles_lookup_test = pd.read_csv("/homes/aclyde11/moses/data/test.csv")
print(smiles_lookup_test.head())

val_loader_food = torch.utils.data.DataLoader(MoleLoader(smiles_lookup_test),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              drop_last=True,
                                              **kwargs)


class customLoss(nn.Module):
    def __init__(self):
        super(customLoss, self).__init__()
        self.mse_loss = nn.MSELoss(reduction="sum")
        # self.crispyLoss = MS_SSIM()

    def compute_kernel(self, x, y):
        x_size = x.shape[0]
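
The listing cuts compute_kernel off after its first line. A minimal sketch of the RBF kernel conventionally used for MMD-style VAE losses, under the assumption that this is what the original computes:

    def compute_kernel(self, x, y):
        # Assumed completion: pairwise RBF kernel between two batches.
        x_size, y_size, dim = x.shape[0], y.shape[0], x.shape[1]
        tiled_x = x.unsqueeze(1).expand(x_size, y_size, dim)
        tiled_y = y.unsqueeze(0).expand(x_size, y_size, dim)
        return torch.exp(-(tiled_x - tiled_y).pow(2).mean(2) / float(dim))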
Example #3
    return args.batch_size


def clip_gradient(optimizer, grad_clip=1.0):
    """
    Clamps gradients computed during backpropagation to avoid exploding gradients.
    :param optimizer: optimizer holding the gradients to be clipped
    :param grad_clip: clamp value; each gradient element is clipped to
        [-grad_clip, grad_clip]
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
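
A typical call site, for context (loss and optimizer are whatever the training loop defines):

loss.backward()                          # compute gradients
clip_gradient(optimizer, grad_clip=1.0)  # clamp them element-wise
optimizer.step()                         # then update parameters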


train_data = MoleLoader(smiles_lookup_train)
# val_data = MoleLoader(smiles_lookup_test)

train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
# val_sampler =   torch.utils.data.distributed.DistributedSampler(val_data)
train_loader_food = torch.utils.data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    sampler=train_sampler,
    drop_last=True,
    # sampler=torch.utils.data.SubsetRandomSampler(indices=list(set(list(np.random.randint(0, len(train_data), size=1250000))))),
    **kwargs)
# val_loader_food = torch.utils.data.DataLoader(
#         val_data,
#         batch_size=args.batch_size, sampler=val_sampler, drop_last=True, #sampler=torch.utils.data.SubsetRandomSampler(indices=list(set(list(np.random.randint(0, len(val_data), size=10000))))),
#         **kwargs)
Example #4
import pandas as pd
import torch
vocab = {k: v for v, k in enumerate(vocab)}  # character -> integer index
charset = {k: v for v, k in vocab.items()}   # integer index -> character
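
For context, a small round trip through the two mappings (the SMILES string is illustrative):

s = "CCO"                                     # illustrative SMILES string
ids = [vocab[c] for c in s]                   # encode characters to indices
assert "".join(charset[i] for i in ids) == s  # decode back to the string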

# model_load1 = {'decoder' : '/homes/aclyde11/imageVAE/combo/model/decoder1_epoch_111.pt', 'encoder':'/homes/aclyde11/imageVAE/smi_smi/model/encoder_epoch_100.pt'}
cuda = True
data_size = 1400000
torch.manual_seed(seed)
output_dir = '/homes/aclyde11/imageVAE/combo/results/'
save_files = '/homes/aclyde11/imageVAE/combo/model/'
device = torch.device("cuda" if cuda else "cpu")
kwargs = {'num_workers': 32, 'pin_memory': True} if cuda else {}

# NOTE: train_data and val_data currently read the same SMILES file.
train_data = MoleLoader(pd.read_csv("/homes/aclyde11/zinc/zinc_cleaned.smi",
                                    sep=' ',
                                    header=None,
                                    engine='c',
                                    low_memory=False),
                        vocab,
                        max_len=70)
val_data = MoleLoader(pd.read_csv("/homes/aclyde11/zinc/zinc_cleaned.smi",
                                  sep=' ',
                                  header=None,
                                  engine='c',
                                  low_memory=False),
                      vocab,
                      max_len=70)

train_loader_food = torch.utils.data.DataLoader(train_data,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                drop_last=True,
                                                **kwargs)
Example #5
import os

import torch
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
def main():
    global best_prec1, args

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        # 'env://' reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE
        # from the environment (set by the launcher).
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    model = None
    if args.pretrained:
        print("=> using pre-trained model")
        # model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model")
        # model = models.__dict__[args.arch]()
        checkpoint = torch.load(
            '/home/aclyde11/imageVAE/im_im_small/model/epoch_67.pt',
            map_location=torch.device('cpu'))
        encoder = PictureEncoder()
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder = PictureDecoder()
        decoder.load_state_dict(checkpoint['decoder_state_dict'], strict=False)
        model = GeneralVae(encoder, decoder)

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    # optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)
    print(args.lr)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)
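    # (apex.amp opt levels: "O0" pure FP32, "O1" mixed precision with casts,
    # "O2" "almost FP16", "O3" pure FP16.)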

    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    criterion = customLoss()

    train_dataset = MoleLoader(smiles_lookup_train)
    val_dataset = MoleLoader(smiles_lookup_test)

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    best_prec1 = 100000
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 < best_prec1
            # Lower is better here (best_prec1 starts at 100000), so keep the min.
            best_prec1 = min(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
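
The excerpt ends here; scripts like this one typically finish with the standard entry-point guard:

if __name__ == '__main__':
    main()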