# --- Code example #1: beta-VAE rate-distortion training ---
def main(args):
    """Train a beta-VAE and log per-epoch distortion/rate to a CSV file.

    Writes ``rate_distortion.csv`` into the run's result directory with one
    row per epoch (plus a baseline row at epoch -1 for the untrained model).

    Args:
        args: parsed command-line namespace; reads ``beta``, ``seed``,
            ``data``, ``dataroot``, ``type``, ``target``, ``bstrain``,
            ``nworkers``, ``lr``, ``milestones``, ``maxepoch``, ``batchout``.
    """
    logger, result_dir, dir_name = utils.config_backup_get_log(args, __file__)
    csv_name = os.path.join(result_dir, 'rate_distortion.csv')

    device = utils.get_device()
    utils.set_seed(args.seed, device)

    trainloader = dataset.get_trainloader(args.data, args.dataroot, args.type,
                                          args.target, args.bstrain,
                                          args.nworkers)

    import models
    model = models.get_vae(args.data, L=10).to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.5, 0.999),
                           weight_decay=1e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=0.1)

    print('==> Start training ..')
    start = time.time()
    # Context manager guarantees the CSV handle is closed even when training
    # raises (the previous version leaked the file object on error).
    with open(csv_name, 'w', newline='') as file:
        csvwriter = csv.writer(file)
        csvwriter.writerow(['beta', args.beta])

        # Epoch -1: evaluate the untrained model as a rate/distortion baseline.
        loss, distortion, rate = test(model, trainloader, logger, device,
                                      args.beta)
        csvwriter.writerow([-1, distortion, rate, distortion + rate])
        print('EPOCH %3d LOSS: %.4F, DISTORTION: %.4f, RATE: %.4f, D+R: %.4f' %
              (-1, loss, distortion, rate, distortion + rate))
        for epoch in range(args.maxepoch):
            loss, distortion, rate = train(epoch, model, trainloader,
                                           optimizer, scheduler, logger,
                                           device, args.beta)
            csvwriter.writerow([epoch, distortion, rate, distortion + rate])
            print('EPOCH %3d LOSS: %.4F, DISTORTION: %.4f, RATE: %.4f, D+R: %.4f' %
                  (epoch, loss, distortion, rate, distortion + rate))

    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(
        int(hours), int(minutes), seconds))

    if args.batchout:
        # Truncate (or create) the batch-marker file; presumably a completion
        # flag for an external batch driver — TODO confirm with caller.
        f = open('temp_result.txt', 'w', newline='')
        f.close()
# --- Code example #2: adversarial autoencoder (AAE) training ---
def main(args):
    """Train an adversarial autoencoder (AAE) and evaluate AUROC/AUPR.

    Trains encoder/decoder/discriminator for ``args.maxepoch`` epochs, then
    runs one final evaluation on the test split and saves a checkpoint with
    all three state dicts under ``./chpt``.

    Args:
        args: parsed command-line namespace; reads ``seed``, ``data``,
            ``dataroot``, ``target``, ``bstrain``, ``bstest``, ``nworkers``,
            ``lr``, ``milestones``, ``maxepoch``.
    """
    logger, result_dir, dir_name = utils.config_backup_get_log(args,__file__)
    device = utils.get_device()
    utils.set_seed(args.seed, device)

    trainloader = dataset.get_trainloader(args.data, args.dataroot, args.target, args.bstrain, args.nworkers)
    testloader = dataset.get_testloader(args.data, args.dataroot, args.target, args.bstest, args.nworkers)

    import models
    encoder, decoder, discriminator = models.get_aae(args.data)
    encoder.to(device)
    decoder.to(device)
    discriminator.to(device)

    # Binary cross-entropy for the adversarial game, L1 for reconstruction.
    adversarial_loss = torch.nn.BCELoss().to(device)
    pixelwise_loss = torch.nn.L1Loss().to(device)

    # One optimizer for the autoencoder (encoder+decoder), one for the critic.
    optimizer_G = torch.optim.Adam(itertools.chain(encoder.parameters(), decoder.parameters()), lr=args.lr, betas=(0.5, 0.999), weight_decay=1e-4)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999), weight_decay=1e-4)

    scheduler_G = optim.lr_scheduler.MultiStepLR(optimizer_G, milestones=args.milestones, gamma=0.1)
    scheduler_D = optim.lr_scheduler.MultiStepLR(optimizer_D, milestones=args.milestones, gamma=0.1)

    chpt_name = 'AAE_%s_target%s_seed%s.pth'%(args.data, str(args.target), str(args.seed))
    chpt_name = os.path.join("./chpt",chpt_name)

    print('==> Start training ..')
    start = time.time()
    for epoch in range(args.maxepoch):
        train(epoch, adversarial_loss, pixelwise_loss, encoder, decoder, discriminator, trainloader, optimizer_G, optimizer_D, scheduler_G, scheduler_D, logger, device)

    # Single final evaluation after the last epoch (no per-epoch validation).
    auroc, aupr, _ = test(encoder, decoder, testloader, device)
    print('Epoch: %4d AUROC: %.4f AUPR: %.4f'%(epoch, auroc, aupr))
    logger.write('Epoch: %4d AUROC: %.4f AUPR: %.4f \n'%(epoch, auroc, aupr))
    state = {
        'encoder': encoder.state_dict(), 
        'decoder': decoder.state_dict(), 
        'discriminator': discriminator.state_dict(), 
        'auroc': auroc, 
        'epoch': epoch}
    torch.save(state, chpt_name)

    end = time.time()
    hours, rem = divmod(end-start, 3600)
    minutes, seconds = divmod(rem, 60)
    print('AUROC... ', auroc)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(int(hours),int(minutes),seconds))
    logger.write("AUROC: %.8f\n"%(auroc))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(int(hours),int(minutes),seconds))
# --- Code example #3: MC-dropout PGN encoder training ---
def main(args):
    """Train a dropout PGN encoder, then report test AUROC/AUPR and save a
    checkpoint under ``./chpt``.

    Args:
        args: parsed command-line namespace; reads ``seed``, ``data``,
            ``dataroot``, ``target``, ``bstrain``, ``bstest``, ``nworkers``,
            ``dropoutp``, ``lr``, ``milestones``, ``maxepoch``, ``mcdropoutT``.
    """
    logger, result_dir, dir_name = utils.config_backup_get_log(args, __file__)
    device = utils.get_device()
    utils.set_seed(args.seed, device)

    # Loaders for the training split and the held-out evaluation split.
    trainloader = dataset.get_trainloader(args.data, args.dataroot,
                                          args.target, args.bstrain,
                                          args.nworkers)
    testloader = dataset.get_testloader(args.data, args.dataroot, args.target,
                                        args.bstest, args.nworkers)

    import models
    model = models.get_pgn_encoder(args.data, args.dropoutp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           betas=(0.5, 0.999), weight_decay=1e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.1)

    # Checkpoint path encodes dataset, target class and seed.
    chpt_name = os.path.join(
        "./chpt",
        'GPN_%s_target%s_seed%s.pth' % (args.data, str(args.target),
                                        str(args.seed)))

    print('==> Start training ..')
    t_begin = time.time()

    for epoch in range(args.maxepoch):
        train(epoch, model, trainloader, optimizer, scheduler, logger, device)

    # Final evaluation with MC-dropout sampling, then persist the model.
    auroc, aupr, _ = test(model, testloader, args.mcdropoutT, device)
    print('Epoch: %4d AUROC: %.6f AUPR: %.6f' % (epoch, auroc, aupr))
    logger.write('Epoch: %4d AUROC: %.6f AUPR: %.6f \n' % (epoch, auroc, aupr))
    torch.save({'model': model.state_dict(), 'auroc': auroc, 'epoch': epoch},
               chpt_name)

    elapsed = time.time() - t_begin
    hours, rem = divmod(elapsed, 3600)
    minutes, seconds = divmod(rem, 60)
    print('AUROC... ', auroc)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    logger.write("AUROC: %.8f\n" % (auroc))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(
        int(hours), int(minutes), seconds))
# --- Code example #4: GPND training ---
def main(args):
    """Train GPND (encoder, generator, two discriminators) and checkpoint it.

    Runs ``args.maxepoch`` training epochs with a periodic AUROC probe,
    evaluates once at the end, and saves all four state dicts under
    ``./chpt``.

    Args:
        args: parsed command-line namespace; reads ``seed``, ``data``,
            ``dataroot``, ``target``, ``bstrain``, ``bstest``, ``nworkers``,
            ``lr``, ``milestones``, ``gamma``, ``maxepoch``.
    """
    logger, result_dir, dir_name = utils.config_backup_get_log(args, __file__)
    device = utils.get_device()
    utils.set_seed(args.seed, device)

    trainloader = dataset.get_trainloader(args.data, args.dataroot,
                                          args.target, args.bstrain,
                                          args.nworkers)
    testloader = dataset.get_testloader(args.data, args.dataroot, args.target,
                                        args.bstest, args.nworkers)

    import models
    encoder, generator, discriminator, discriminator_z = models.get_gpnd(
        args.data)
    encoder.to(device)
    generator.to(device)
    discriminator.to(device)
    discriminator_z.to(device)

    def _adam(params):
        # All five optimizers share identical Adam hyper-parameters.
        return optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))

    def _step_scheduler(opt):
        # Shared step-decay schedule for every optimizer.
        return optim.lr_scheduler.MultiStepLR(opt,
                                              milestones=args.milestones,
                                              gamma=args.gamma)

    optimizer_G = _adam(generator.parameters())
    optimizer_D = _adam(discriminator.parameters())
    optimizer_E = _adam(encoder.parameters())
    # GE jointly updates encoder and generator (reconstruction path).
    optimizer_GE = _adam(list(encoder.parameters()) +
                         list(generator.parameters()))
    optimizer_ZD = _adam(discriminator_z.parameters())

    schedulers = tuple(
        _step_scheduler(opt)
        for opt in (optimizer_G, optimizer_D, optimizer_E, optimizer_GE,
                    optimizer_ZD))

    chpt_name = 'GPND_%s_target%s_seed%s.pth' % (args.data, str(
        args.target), str(args.seed))
    chpt_name = os.path.join("./chpt", chpt_name)

    print('==> Start training ..')
    start = time.time()
    for epoch in range(args.maxepoch):
        train(epoch, encoder, generator, discriminator, discriminator_z,
              trainloader, optimizer_G, optimizer_D, optimizer_E, optimizer_GE,
              optimizer_ZD, schedulers, logger, device)
        # Progress probe: evaluate AUROC every 20th epoch once past epoch 79
        # (late-training checkpoints only; early epochs are skipped).
        if epoch > 79 and epoch % 20 == 0:
            auroc, aupr, _ = test(encoder, generator, testloader, device)
            print(auroc)

    # Final evaluation and checkpoint after the last epoch.
    auroc, aupr, _ = test(encoder, generator, testloader, device)
    print('Epoch: %4d AUROC: %.4f AUPR: %.4f' % (epoch, auroc, aupr))
    state = {
        'encoder': encoder.state_dict(),
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'discriminator_z': discriminator_z.state_dict(),
        'auroc': auroc,
        'epoch': epoch
    }
    torch.save(state, chpt_name)

    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print('AUROC... ', auroc)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    logger.write("AUROC: %.8f\n" % (auroc))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(
        int(hours), int(minutes), seconds))