def main(args):
    logger, result_dir, dir_name = utils.config_backup_get_log(args, __file__)
    device = utils.get_device()
    utils.set_seed(args.seed, device)

    # load data
    pXY = data.Normal(args.dim, args.rho, device)
    model = methods.setup_method(args.method, args.dim, args.hidden,
                                 args.layers).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)

    for step in range(1, args.steps + 1):
        X, Y = pXY.draw_samples(args.N)
        XY_package = torch.cat(
            [X.repeat_interleave(X.size(0), 0),
             Y.repeat(Y.size(0), 1)], dim=1)
        optim.zero_grad()
        L = model(X, Y, XY_package)
        L.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optim.step()
        print('step {:4d} | '.format(step), end='')
        print('ln N: {:.2f} | I(X,Y): {:.2f} | est. I(X,Y): {:.2f}'.format(
            math.log(args.N), pXY.I(), -L.item()))

    # Final evaluation
    M = args.N
    X, Y = pXY.draw_samples(M)
    XY_package = torch.cat([X.repeat_interleave(M, 0), Y.repeat(M, 1)], dim=1)
    model.eval()
    with torch.no_grad():  # no gradient tracking needed for the final estimate
        test_MI = -model(X, Y, XY_package).item()
    print('{:6.2f}'.format(test_MI))
    print('ln({:d}): {:.2f} | I(X,Y): {:.2f}'.format(M, math.log(M), pXY.I()))
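# ---------------------------------------------------------------------------
# Illustration (not part of the original script; the tensors below are made
# up for the demo): XY_package enumerates all N*N (x_i, y_j) combinations,
# which critic-based MI estimators score to contrast the joint distribution
# against the product of marginals. Printing ln N next to the estimate is a
# useful sanity check, since InfoNCE-style bounds cannot exceed ln N.
import torch

X = torch.tensor([[1.], [2.]])
Y = torch.tensor([[10.], [20.]])
pairs = torch.cat([X.repeat_interleave(X.size(0), 0),
                   Y.repeat(Y.size(0), 1)], dim=1)
print(pairs)  # [[1., 10.], [1., 20.], [2., 10.], [2., 20.]]
# ---------------------------------------------------------------------------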
Example #2
def main(args):
    logger, result_dir, dir_name = utils.config_backup_get_log(args, __file__)
    csv_name = os.path.join(result_dir, 'rate_distortion.csv')
    file = open(csv_name, 'w', newline='')
    csvwriter = csv.writer(file)
    csvwriter.writerow(['beta', args.beta])  # record the rate-distortion weight

    device = utils.get_device()
    utils.set_seed(args.seed, device)

    trainloader = dataset.get_trainloader(args.data, args.dataroot, args.type,
                                          args.target, args.bstrain,
                                          args.nworkers)

    import models
    model = models.get_vae(args.data, L=10).to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.5, 0.999),
                           weight_decay=1e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=0.1)

    chpt_name = 'betaVAE_%s_target%s_seed%s.pth' % (args.data, str(
        args.target), str(args.seed))
    chpt_name = os.path.join("./chpt", chpt_name)

    print('==> Start training ..')
    start = time.time()
    # baseline rate-distortion before any training, logged as epoch -1
    loss, distortion, rate = test(model, trainloader, logger, device,
                                  args.beta)
    csvwriter.writerow([-1, distortion, rate, distortion + rate])
    print('EPOCH %3d LOSS: %.4f, DISTORTION: %.4f, RATE: %.4f, D+R: %.4f' %
          (-1, loss, distortion, rate, distortion + rate))
    for epoch in range(args.maxepoch):
        loss, distortion, rate = train(epoch, model, trainloader, optimizer,
                                       scheduler, logger, device, args.beta)
        csvwriter.writerow([epoch, distortion, rate, distortion + rate])
        print('EPOCH %3d LOSS: %.4f, DISTORTION: %.4f, RATE: %.4f, D+R: %.4f' %
              (epoch, loss, distortion, rate, distortion + rate))

    file.close()
    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(
        int(hours), int(minutes), seconds))

    if args.batchout:
        # create/truncate the batch-result file (no scalar result to report)
        open('temp_result.txt', 'w').close()
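# ---------------------------------------------------------------------------
# Hypothetical sketch of the train() this script calls (the real one lives
# elsewhere in the repo): one beta-VAE epoch whose loss decomposes into the
# distortion and rate terms logged above. It assumes model(x) returns
# (reconstruction, mu, logvar); the actual model contract may differ.
import torch
import torch.nn.functional as F

def train(epoch, model, trainloader, optimizer, scheduler, logger, device,
          beta):
    model.train()
    tot_loss = tot_dist = tot_rate = 0.
    for x, _ in trainloader:
        x = x.to(device)
        recon, mu, logvar = model(x)
        # distortion: per-sample reconstruction error
        distortion = F.mse_loss(recon, x, reduction='sum') / x.size(0)
        # rate: per-sample KL(q(z|x) || N(0, I))
        rate = -0.5 * torch.sum(
            1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
        loss = distortion + beta * rate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tot_loss += loss.item()
        tot_dist += distortion.item()
        tot_rate += rate.item()
    scheduler.step()
    n = len(trainloader)
    return tot_loss / n, tot_dist / n, tot_rate / n
# ---------------------------------------------------------------------------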
Example #3
def main(args):
    logger, result_dir, dir_name = utils.config_backup_get_log(args, __file__)
    device = utils.get_device()
    utils.set_seed(args.seed, device)

    trainloader = dataset.get_trainloader(args.data, args.dataroot,
                                          args.target, args.bstrain,
                                          args.nworkers)
    testloader = dataset.get_testloader(args.data, args.dataroot, args.target,
                                        args.bstest, args.nworkers)
    
    import models
    encoder, decoder, discriminator = models.get_aae(args.data)
    encoder.to(device)
    decoder.to(device)
    discriminator.to(device)

    # Use binary cross-entropy loss
    adversarial_loss = torch.nn.BCELoss().to(device)
    pixelwise_loss = torch.nn.L1Loss().to(device)

    optimizer_G = torch.optim.Adam(itertools.chain(encoder.parameters(),
                                                   decoder.parameters()),
                                   lr=args.lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=1e-4)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=1e-4)

    scheduler_G = optim.lr_scheduler.MultiStepLR(optimizer_G,
                                                 milestones=args.milestones,
                                                 gamma=0.1)
    scheduler_D = optim.lr_scheduler.MultiStepLR(optimizer_D,
                                                 milestones=args.milestones,
                                                 gamma=0.1)

    chpt_name = 'AAE_%s_target%s_seed%s.pth' % (args.data, str(args.target),
                                                str(args.seed))
    chpt_name = os.path.join("./chpt", chpt_name)

    print('==> Start training ..')
    best_auroc = 0.
    start = time.time()
    for epoch in range(args.maxepoch):
        train(epoch, adversarial_loss, pixelwise_loss, encoder, decoder,
              discriminator, trainloader, optimizer_G, optimizer_D,
              scheduler_G, scheduler_D, logger, device)

    auroc, aupr, _ = test(encoder, decoder, testloader, device)
    print('Epoch: %4d AUROC: %.4f AUPR: %.4f' % (epoch, auroc, aupr))
    logger.write('Epoch: %4d AUROC: %.4f AUPR: %.4f \n' % (epoch, auroc, aupr))
    state = {
        'encoder': encoder.state_dict(),
        'decoder': decoder.state_dict(),
        'discriminator': discriminator.state_dict(),
        'auroc': auroc,
        'epoch': epoch
    }
    torch.save(state, chpt_name)

    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print('AUROC... ', auroc)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    logger.write("AUROC: %.8f\n" % (auroc))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(
        int(hours), int(minutes), seconds))
Example #4
def main():
    logger, result_dir, dir_name = utils.config_backup_get_log(
        args, "matcher", __file__)
    device = utils.get_device()
    utils.set_seed(args.seed, device)  # set random seed

    chpt_name = 'matcher_%s_%s' % (args.dataname, args.method)
    print('Checkpoint name:', chpt_name)

    # load data
    assert args.dataname in ['Haywrd', 'ykdelB']
    dataconfig = configs.dataconfig[args.sartype.lower()]
    traindata = dataconfig[args.dataname]['traindata']
    testdata = dataconfig[args.dataname]['testdata']
    
    trainloader = loader_grd.setup_trainloader(traindata, args)
    # testloader = loader_grd.setup_testloader(traindata, args)
    # testloader_real = loader_grd.setup_testloader(testdata, args)
    gtmat = loader_grd.load_gtmat(args.sartype,
                                  dataconfig[args.dataname]['gtmat'])
    
    # setup descriptor
    state_dict = torch.load(args.descriptor_name, map_location=device)
    args_load = state_dict['args']
    descriptor = model.dcsnn_full(args_load['nglobal'],
                                  256,
                                  pretrained=True,
                                  K=args_load['moco_k'],
                                  m=args_load['moco_m'],
                                  T=args_load['temp'],
                                  no_moco=args_load['no_moco'],
                                  ntrain=len(traindata)).to(device)
    descriptor.load_state_dict(state_dict['net'])  # load pretrained weights
    lf_generator = descriptor.lf_generator

    # setup matcher
    matcher = SuperGlue(sinkhorn_iterations=args.sinkhorn_iter,
                        match_threshold=args.match_threshold).to(device)
    # optimizer = torch.optim.SGD(matcher.parameters(), lr=args.lr,
    #                             weight_decay=1e-4, momentum=0.)
    optimizer = torch.optim.Adam(matcher.parameters(), lr=args.lr,
                                 weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[15, 20],
                                                     gamma=0.1)

    print('==> Start training ..')
    start = time.time()
    best_mAP = -1

    # Optionally resume from an earlier matcher checkpoint:
    # state_dict = torch.load('./chpt/matcher_Haywrd_tanta_attempt1.pth')
    # matcher.load_state_dict(state_dict['net'])

    for epoch in range(args.maxepoch):
        train(epoch, lf_generator, matcher, trainloader, optimizer, device,
              logger)
        scheduler.step()  # advance the MultiStepLR schedule

    state = {'method': args.method, 'net': matcher.state_dict(),
             'args': vars(args)}
    if result_dir is not None:
        torch.save(state, './%s/%s.pth' % (result_dir, chpt_name))
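# ---------------------------------------------------------------------------
# Illustrative core of the SuperGlue matcher configured above: Sinkhorn
# normalization alternately normalizes rows and columns of a score matrix in
# log-space so it approaches a doubly-stochastic soft assignment. The real
# SuperGlue also adds dustbin rows/columns for unmatched keypoints; this
# square-matrix version is only a sketch of the idea.
import torch

def sinkhorn(log_scores, iters):
    # log_scores: (M, N) matching scores; returns a soft assignment matrix
    for _ in range(iters):
        log_scores = log_scores - torch.logsumexp(log_scores, dim=1,
                                                  keepdim=True)
        log_scores = log_scores - torch.logsumexp(log_scores, dim=0,
                                                  keepdim=True)
    return log_scores.exp()
# ---------------------------------------------------------------------------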
Example #5
def main(args):
    logger, result_dir, dir_name = utils.config_backup_get_log(args, __file__)
    device = utils.get_device()
    utils.set_seed(args.seed, device)

    trainloader = dataset.get_trainloader(args.data, args.dataroot,
                                          args.target, args.bstrain,
                                          args.nworkers)
    testloader = dataset.get_testloader(args.data, args.dataroot, args.target,
                                        args.bstest, args.nworkers)

    import models
    model = models.get_pgn_encoder(args.data, args.dropoutp).to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.5, 0.999),
                           weight_decay=1e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=0.1)

    chpt_name = 'GPN_%s_target%s_seed%s.pth' % (args.data, str(
        args.target), str(args.seed))
    chpt_name = os.path.join("./chpt", chpt_name)

    print('==> Start training ..')
    start = time.time()

    for epoch in range(args.maxepoch):
        train(epoch, model, trainloader, optimizer, scheduler, logger, device)

    auroc, aupr, _ = test(model, testloader, args.mcdropoutT, device)
    print('Epoch: %4d AUROC: %.6f AUPR: %.6f' % (epoch, auroc, aupr))
    logger.write('Epoch: %4d AUROC: %.6f AUPR: %.6f \n' % (epoch, auroc, aupr))
    state = {'model': model.state_dict(), 'auroc': auroc, 'epoch': epoch}
    torch.save(state, chpt_name)

    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print('AUROC... ', auroc)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    logger.write("AUROC: %.8f\n" % (auroc))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(
        int(hours), int(minutes), seconds))
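# ---------------------------------------------------------------------------
# Hypothetical sketch of the MC-dropout scoring behind
# test(model, testloader, args.mcdropoutT, device): keep only the dropout
# layers stochastic at inference and average T forward passes. Assumes the
# encoder outputs one logit per sample; the repo's test() may aggregate
# differently.
import torch

def enable_dropout(model):
    # re-enable dropout layers only, leaving batchnorm in eval mode
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout):
            m.train()

@torch.no_grad()
def mc_dropout_scores(model, x, T):
    model.eval()
    enable_dropout(model)
    probs = torch.stack([torch.sigmoid(model(x)) for _ in range(T)])
    return probs.mean(0), probs.var(0)  # predictive mean and uncertainty
# ---------------------------------------------------------------------------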
Example #6
def main():
    logger, result_dir, _ = utils.config_backup_get_log(args, __file__)

    device = utils.get_device()
    utils.set_seed(args.seed, device)  # set random seed

    dataset = COVID19DataSet(root=args.datapath,
                             ctonly=args.ctonly)  # load dataset
    trainset, testset = split_dataset(dataset=dataset, logger=logger)

    if args.model.lower() in ['mobilenet']:
        net = mobilenet_v2(task='classification',
                           moco=False,
                           ctonly=args.ctonly).to(device)
    elif args.model.lower() in ['densenet']:
        net = densenet121(task='classification',
                          moco=False,
                          ctonly=args.ctonly).to(device)
    else:
        raise ValueError('unsupported model: %s' % args.model)

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.1)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.bstrain,
                                              shuffle=True,
                                              num_workers=args.nworkers)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.bstest,
                                             shuffle=False,
                                             num_workers=args.nworkers)

    best_auroc = 0.
    best_aupr = 0.
    best_epoch = 0
    print('==> Start training ..')
    start = time.time()
    for epoch in range(args.maxepoch):
        net = train(epoch, net, trainloader, criterion, optimizer, scheduler,
                    args.model, device)
        scheduler.step()
        if epoch % 5 == 0:
            auroc, aupr, f1_score, accuracy = validate(net, testloader, device)
            logger.write(
                'Epoch:%3d | AUROC: %5.4f | AUPR: %5.4f | F1_Score: %5.4f | Accuracy: %5.4f\n'
                % (epoch, auroc, aupr, f1_score, accuracy))
            if auroc > best_auroc:
                best_auroc = auroc
                best_aupr = aupr
                best_epoch = epoch
                print("save checkpoint...")
                torch.save(net.state_dict(),
                           './%s/%s.pth' % (result_dir, args.model))

    auroc, aupr, f1_score, accuracy = validate(net, testloader, device)
    logger.write(
        'Epoch:%3d | AUROC: %5.4f | AUPR: %5.4f | F1_Score: %5.4f | Accuracy: %5.4f\n'
        % (epoch, auroc, aupr, f1_score, accuracy))

    if args.batchout:
        with open('temp_result.txt', 'w') as f:
            f.write("%10.8f\n" % (best_auroc))
            f.write("%10.8f\n" % (best_aupr))
            f.write("%d" % (best_epoch))

    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(
        int(hours), int(minutes), seconds))
    return True
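# ---------------------------------------------------------------------------
# Hypothetical entry point; each script defines its own parser elsewhere, and
# the defaults below are illustrative, not the authors'. The flag names
# mirror the args attributes this example reads.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--datapath', default='./data')
    parser.add_argument('--model', default='mobilenet',
                        choices=['mobilenet', 'densenet'])
    parser.add_argument('--ctonly', action='store_true')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--bstrain', type=int, default=16)
    parser.add_argument('--bstest', type=int, default=16)
    parser.add_argument('--nworkers', type=int, default=4)
    parser.add_argument('--maxepoch', type=int, default=100)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--batchout', action='store_true')
    args = parser.parse_args()
    main()
# ---------------------------------------------------------------------------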
Example #7
def main(args):
    logger, result_dir, dir_name = utils.config_backup_get_log(args, __file__)
    device = utils.get_device()
    utils.set_seed(args.seed, device)

    trainloader = dataset.get_trainloader(args.data, args.dataroot,
                                          args.target, args.bstrain,
                                          args.nworkers)
    testloader = dataset.get_testloader(args.data, args.dataroot, args.target,
                                        args.bstest, args.nworkers)

    import models
    encoder, generator, discriminator, discriminator_z = models.get_gpnd(
        args.data)
    encoder.to(device)
    generator.to(device)
    discriminator.to(device)
    discriminator_z.to(device)

    optimizer_G = optim.Adam(generator.parameters(),
                             lr=args.lr,
                             betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(),
                             lr=args.lr,
                             betas=(0.5, 0.999))
    optimizer_E = optim.Adam(encoder.parameters(),
                             lr=args.lr,
                             betas=(0.5, 0.999))
    optimizer_GE = optim.Adam(list(encoder.parameters()) +
                              list(generator.parameters()),
                              lr=args.lr,
                              betas=(0.5, 0.999))
    optimizer_ZD = optim.Adam(discriminator_z.parameters(),
                              lr=args.lr,
                              betas=(0.5, 0.999))

    scheduler_G = optim.lr_scheduler.MultiStepLR(optimizer_G,
                                                 milestones=args.milestones,
                                                 gamma=args.gamma)
    scheduler_D = optim.lr_scheduler.MultiStepLR(optimizer_D,
                                                 milestones=args.milestones,
                                                 gamma=args.gamma)
    scheduler_E = optim.lr_scheduler.MultiStepLR(optimizer_E,
                                                 milestones=args.milestones,
                                                 gamma=args.gamma)
    scheduler_GE = optim.lr_scheduler.MultiStepLR(optimizer_GE,
                                                  milestones=args.milestones,
                                                  gamma=args.gamma)
    scheduler_ZD = optim.lr_scheduler.MultiStepLR(optimizer_ZD,
                                                  milestones=args.milestones,
                                                  gamma=args.gamma)
    schedulers = (scheduler_G, scheduler_D, scheduler_E, scheduler_GE,
                  scheduler_ZD)
    chpt_name = 'GPND_%s_target%s_seed%s.pth' % (args.data, str(
        args.target), str(args.seed))
    chpt_name = os.path.join("./chpt", chpt_name)

    print('==> Start training ..')
    start = time.time()
    for epoch in range(args.maxepoch):
        train(epoch, encoder, generator, discriminator, discriminator_z,
              trainloader, optimizer_G, optimizer_D, optimizer_E, optimizer_GE,
              optimizer_ZD, schedulers, logger, device)
        if epoch > 79 and epoch % 20 == 0:
            # periodic check of novelty-detection quality late in training
            auroc, aupr, _ = test(encoder, generator, testloader, device)
            print(auroc)

    auroc, aupr, _ = test(encoder, generator, testloader, device)
    print('Epoch: %4d AUROC: %.4f AUPR: %.4f' % (epoch, auroc, aupr))
    state = {
        'encoder': encoder.state_dict(),
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'discriminator_z': discriminator_z.state_dict(),
        'auroc': auroc,
        'epoch': epoch
    }
    torch.save(state, chpt_name)

    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print('AUROC... ', auroc)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    logger.write("AUROC: %.8f\n" % (auroc))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(
        int(hours), int(minutes), seconds))
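# ---------------------------------------------------------------------------
# Hypothetical sketch of the AUROC/AUPR evaluation these scripts share; the
# repo's test() may score novelty differently (GPND in particular uses a
# fuller probabilistic score), so per-sample reconstruction error is only
# the simplest stand-in. Assumes testloader yields (image, is_novel) pairs
# with binary labels.
import torch
from sklearn.metrics import average_precision_score, roc_auc_score

@torch.no_grad()
def test(encoder, generator, testloader, device):
    scores, labels = [], []
    for x, y in testloader:
        x = x.to(device)
        recon = generator(encoder(x))
        err = (recon - x).flatten(1).pow(2).mean(1)  # per-sample MSE
        scores.append(err.cpu())
        labels.append(y)
    scores = torch.cat(scores).numpy()
    labels = torch.cat(labels).numpy()
    return (roc_auc_score(labels, scores),
            average_precision_score(labels, scores), scores)
# ---------------------------------------------------------------------------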
Example #8
def main():
    logger, result_dir, dir_name = utils.config_backup_get_log(
        args, "descriptor", __file__)
    device = utils.get_device()
    utils.set_seed(args.seed, device)  # set random seed

    chpt_name = 'descriptor_%s_%s_%d_%d' % (args.dataname, args.method,
                                            args.nglobal, args.nlocal)
    print('Checkpoint name:', chpt_name)

    # load data
    assert args.dataname in ['Haywrd', 'ykdelB']
    dataconfig = configs.dataconfig[args.sartype.lower()]
    traindata = dataconfig[args.dataname]['traindata']
    testdata = dataconfig[args.dataname]['testdata']

    trainloader = loader_grd.setup_trainloader(traindata, args)
    testloader = loader_grd.setup_testloader(traindata, args)
    testloader_real = loader_grd.setup_testloader(testdata, args)
    gtmat = loader_grd.load_gtmat(args.sartype,
                                  dataconfig[args.dataname]['gtmat'])

    # setup architecture
    net = model.dcsnn_full(args.nglobal,
                           args.nlocal,
                           pretrained=True,
                           K=args.moco_k,
                           m=args.moco_m,
                           T=args.temp,
                           no_moco=args.no_moco,
                           ntrain=len(traindata)).to(device)
    # matcher = SuperGlue(sinkhorn_iterations=100, match_threshold=0.2).to(device)

    # setup criterion and optimizer
    criterion = losses.setup_loss(args.method,
                                  device,
                                  lamda=(args.lamreg, args.lamkp, args.lamld),
                                  temp=args.temp,
                                  self_learning=not args.no_selflearning)
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.lr,
                                weight_decay=1e-4,
                                momentum=0.)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[15, 25],
                                                     gamma=0.1)

    print('==> Start training ..')
    start = time.time()
    best_mAP = -1
    best_epoch = 0

    for epoch in range(args.maxepoch):
        train(epoch, net, trainloader, criterion, optimizer, device, logger)
        # print("after train \n", torch.cuda.memory_allocated())
        scheduler.step()  # update optimizer lr
        if epoch % 10 == 0:
            mAP, _ = test(gtmat, testloader, testloader_real, device, net)
            logger.write("mAP: %0.6f" % (mAP))
            if mAP > best_mAP:
                best_mAP = mAP
                best_epoch = epoch
                state = {
                    'method': args.method,
                    'net': net.state_dict(),
                    'mAP': mAP,
                    'args': vars(args)
                }
                if result_dir is not None:
                    torch.save(state, './%s/%s.pth' % (result_dir, chpt_name))

    logger.write("Best mAP: %.6f (Epoch: %d)" % (best_mAP, best_epoch))

    if args.batchout:
        with open('temp_result.txt', 'w') as f:
            f.write("%10.8f\n" % (best_mAP))
            f.write("%d" % (best_epoch))

    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
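# ---------------------------------------------------------------------------
# Hypothetical sketch of retrieval mAP against a ground-truth match matrix
# like gtmat above: rank gallery items per query by descriptor similarity
# and average precision over the true matches. sim and gtmat are assumed to
# be (num_query, num_gallery) NumPy arrays; the repo's test() builds the
# similarities from the trained descriptor.
import numpy as np

def mean_average_precision(sim, gtmat):
    aps = []
    for q in range(sim.shape[0]):
        order = np.argsort(-sim[q])   # best match first
        rel = gtmat[q, order] > 0     # relevance in ranked order
        if not rel.any():
            continue                  # skip queries with no true match
        prec = np.cumsum(rel) / (np.arange(rel.size) + 1)
        aps.append((prec * rel).sum() / rel.sum())
    return float(np.mean(aps))
# ---------------------------------------------------------------------------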