Beispiel #1
0
        #loss_cripsy = self.crispyLoss(x_recon, x)

        return loss_MSE + min(
            1.0,
            float(round(epochs / 2 + 0.75)) * KLD_annealing) * loss_KLD


# Assemble the VAE from its encoder/decoder halves and move it to the GPU.
encoder, decoder = PictureEncoder(), PictureDecoder()
model = GeneralVae(encoder, decoder).cuda()

# Wrap the model for multi-GPU execution only when it was requested via
# ``data_para`` and more than one CUDA device is visible.
if data_para:
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # NOTE(review): device ids 4-7 are hard-coded. DataParallel presumably
        # needs the parameters to live on device_ids[0]; confirm the current
        # CUDA device is 4 (and that 8 GPUs exist) on the target machine.
        model = nn.DataParallel(model, device_ids=[4, 5, 6, 7])

# Adam over every model parameter; the learning rate comes from the LR constant.
optimizer = optim.Adam(model.parameters(), lr=LR)
#sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5, eta_min=5e-4, last_epoch=-1)

# Loss histories (start empty) and the loss object used by the script.
train_losses, val_losses = [], []
lossf = customLoss()


def get_batch_size(epoch):
    """Return the batch size to use for ``epoch``.

    The size is currently fixed at 700 and does not depend on ``epoch``; an
    epoch-dependent schedule, ``min(64 + 16 * epoch, 322)``, is left disabled.
    """
    fixed_batch_size = 700
    return fixed_batch_size


def train(epoch, train_loader):
    with experiment.train():
        experiment.log_current_epoch(epoch)
Beispiel #2
0
def main():
    """Build the VAE, optimizer and data loaders, then run the training loop.

    Configuration is read from (and partly written back to) the module-level
    ``args`` namespace: ``distributed``, ``gpu``, ``world_size`` and ``lr``
    are overwritten here.  The module-level ``best_prec1`` is updated with the
    lowest validation value seen so far (lower is treated as better).

    Raises:
        NotImplementedError: if ``args.pretrained`` is set; that branch never
            constructs a model.
        AssertionError: if the cuDNN backend is disabled.
    """
    global best_prec1, args

    # Distributed mode is switched on purely by the WORLD_SIZE environment
    # variable being greater than one.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    # Single-process defaults; replaced below when running distributed.
    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.pretrained:
        print("=> using pre-trained model")
        # The original pre-trained path (models.__dict__[args.arch]) is
        # disabled, so no model would exist past this point.  Fail here with a
        # clear message instead of an AttributeError on ``None`` further down.
        raise NotImplementedError(
            "args.pretrained is set, but no pre-trained model is constructed "
            "by this script")
    else:
        print("=> creating model")
        # Weights are restored from a fixed checkpoint; it is mapped to CPU
        # first and the assembled model is moved to the GPU afterwards.
        checkpoint = torch.load(
            '/home/aclyde11/imageVAE/im_im_small/model/epoch_67.pt',
            map_location=torch.device('cpu'))
        encoder = PictureEncoder()
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder = PictureDecoder()
        # strict=False: missing/unexpected decoder keys are tolerated.
        decoder.load_state_dict(checkpoint['decoder_state_dict'], strict=False)
        model = GeneralVae(encoder, decoder)

    if args.sync_bn:
        # Imported lazily so apex.parallel is only needed when requested.
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale the learning rate linearly with the global batch size,
    # relative to a reference batch of 256.
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    print(args.lr)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    criterion = customLoss()

    train_dataset = MoleLoader(smiles_lookup_train)
    val_dataset = MoleLoader(smiles_lookup_test)

    # Samplers are only used in distributed mode; otherwise the training
    # loader shuffles on its own.
    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    # Large sentinel so the first validation value always counts as the best.
    best_prec1 = 100000
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            # Profiling runs stop after a single training epoch.
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # Remember the best (lowest) validation value and save a checkpoint;
        # only the process with local_rank 0 writes checkpoints.
        if args.local_rank == 0:
            is_best = prec1 < best_prec1
            # Lower is better here, so the running best is the minimum.
            # (Using max() would pin it at the sentinel and flag every epoch
            # as the best one.)
            best_prec1 = min(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)