コード例 #1
0
def main(_):
    """Run UNet inference, or interpolate between embeddings, from CLI options.

    The unused positional argument accepts whatever the caller passes
    (NOTE(review): presumably ``argv`` from a ``tf.app.run``-style launcher;
    confirm). All settings come from ``parser.arg_parse()``.

    ``args.embedding_ids`` is parsed as a comma-separated list of integers.

    * When ``args.interpolate`` is falsy, ``model.infer`` is called once.
      A single id is passed as a bare ``int``; several ids are passed as a
      ``list``.
    * Otherwise ``model.interpolate`` is called for each consecutive pair of
      ids. When ``args.uroboros`` is set, the chain is closed by also
      interpolating from the last id back to the first.

    Raises:
        ValueError: if interpolation is requested with fewer than two ids.
            ``ValueError`` subclasses ``Exception``, so handlers written
            against the previous bare ``Exception`` still catch it.
    """
    args = parser.arg_parse()
    config = tf.ConfigProto()
    # Grow GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        model = UNet(batch_size=args.batch_size,
                     cns_embedding_size=args.cns_embedding_size)
        model.register_session(sess)
        model.build_model(is_training=False, inst_norm=args.inst_norm)
        embedding_ids = [int(i) for i in args.embedding_ids.split(",")]

        if not args.interpolate:
            # A lone id is passed as a scalar rather than a one-element list.
            # NOTE(review): presumably model.infer distinguishes the two
            # forms; confirm before simplifying.
            if len(embedding_ids) == 1:
                embedding_ids = embedding_ids[0]
            model.infer(model_dir=args.model_dir,
                        source_obj=args.source_obj,
                        embedding_ids=embedding_ids,
                        save_dir=args.save_dir)
            return

        if len(embedding_ids) < 2:
            raise ValueError(
                "no need to interpolate yourself unless you are a narcissist"
            )
        chains = list(embedding_ids)
        if args.uroboros:
            # Close the loop: also interpolate from the last id back to the first.
            chains.append(chains[0])
        # Walk consecutive (start, end) pairs along the chain.
        for start, end in zip(chains, chains[1:]):
            model.interpolate(model_dir=args.model_dir,
                              source_obj=args.source_obj,
                              between=[start, end],
                              save_dir=args.save_dir,
                              steps=args.steps)
コード例 #2
0
def main(_):
    """Build a UNet in training mode and train it with command-line settings.

    The positional argument is ignored. Options come from
    ``parser.arg_parse()``. The TensorFlow session is configured to grow GPU
    memory on demand rather than reserving it all at start-up.
    """
    args = parser.arg_parse()
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as session:
        net = UNet(
            args.experiment_dir,
            batch_size=args.batch_size,
            experiment_id=args.experiment_id,
            input_width=args.image_size,
            output_width=args.image_size,
            L1_penalty=args.L1_penalty,
            cns_embedding_size=args.cns_embedding_size,
        )
        net.register_session(session)
        net.build_model(is_training=True, inst_norm=args.inst_norm)

        # Training hyper-parameters, forwarded verbatim from the CLI.
        train_options = dict(
            lr=args.lr,
            epoch=args.epoch,
            resume=args.resume,
            schedule=args.schedule,
            flip_labels=args.flip_labels,
            sample_steps=args.sample_steps,
        )
        net.train(**train_options)
コード例 #3
0
def main(_):
    """Build a UNet in training mode and train it, optionally fine-tuning.

    The positional argument is ignored. Options come from
    ``parser.arg_parse()``.

    * When ``args.flip_labels`` is set, the model is built with
      ``no_target_source=True``. Otherwise that keyword is not passed and
      ``build_model``'s own default applies.
    * ``args.fine_tune``, if non-empty, is a comma-separated list of integer
      ids. It is passed to ``model.train`` as a ``set`` of ints; otherwise
      ``fine_tune=None`` is passed.
    """
    args = parser.arg_parse()
    config = tf.ConfigProto()
    # Grow GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        model = UNet(args.experiment_dir,
                     batch_size=args.batch_size,
                     experiment_id=args.experiment_id,
                     input_width=args.image_size,
                     output_width=args.image_size,
                     embedding_num=args.embedding_num,
                     L1_penalty=args.L1_penalty,
                     Lconst_penalty=args.Lconst_penalty,
                     Ltv_penalty=args.Ltv_penalty,
                     Lcategory_penalty=args.Lcategory_penalty,
                     cns_encoder_dir=args.cns_encoder_dir,
                     cns_embedding_size=args.cns_embedding_size)
        model.register_session(sess)

        # Single build_model call. ``no_target_source`` is only supplied when
        # labels are flipped, exactly as the two original branches did.
        build_kwargs = {"is_training": True, "inst_norm": args.inst_norm}
        if args.flip_labels:
            build_kwargs["no_target_source"] = True
        model.build_model(**build_kwargs)

        fine_tune_list = None
        if args.fine_tune:
            fine_tune_list = {int(i) for i in args.fine_tune.split(",")}
        model.train(lr=args.lr,
                    epoch=args.epoch,
                    resume=args.resume,
                    schedule=args.schedule,
                    freeze_encoder=args.freeze_encoder,
                    fine_tune=fine_tune_list,
                    sample_steps=args.sample_steps,
                    flip_labels=args.flip_labels)
コード例 #4
0
ファイル: main.py プロジェクト: caleb-llh/dlminiproj
    if not os.path.exists(args.saved_pkl_dir):
        os.makedirs(args.saved_pkl_dir)

    ## prepare data
    transform = transforms.Compose([transforms.CenterCrop(280),
                                    transforms.ToTensor(),
                                    transforms.Normalize(MEAN, STD)
                                    ])
    trainset = data.PascalVOC(args.data_dir,'train',transform)
    validset = data.PascalVOC(args.data_dir,'val',transform)
    
    loadertr = torch.utils.data.DataLoader(trainset,batch_size=args.train_batch,shuffle=True)
    loadervl = torch.utils.data.DataLoader(validset,batch_size=args.test_batch,shuffle=False)
    
    ## run train/results
    if args.run == 'train':
        train(device, loadertr, loadervl)
        print("\nFinished training.")
    if args.run == 'results':
        results(device, loadervl, validset)
        print("\nFinished producing results. Tail accuracy, top 5 and bottom 5 images are in saved_img folder")
    if args.run == 'plot':
        utils.plot(args)
        print("\nFinished plotting.")

if __name__ == "__main__":
    # main() is called with no arguments. NOTE(review): it presumably reads
    # the parsed options through the module-level ``args`` (e.g.
    # ``args.saved_pkl_dir``), so keep binding that global name here.
    args = parser.arg_parse()
    main()
    
    
    
コード例 #5
0
    print("----------------------------------------------------------")
    print("{:25s}: {}".format("Task", "Time Taken (in seconds)"))
    print()
    print("{:25s}: {:2.3f}".format("Reading addresses", load_batch - read_dir))
    print("{:25s}: {:2.3f}".format("Loading batch",
                                   start_det_loop - load_batch))
    print("{:25s}: {:2.3f}".format(
        "Detection (" + str(len(imlist)) + " images)", output_recast))
    print("{:25s}: {:2.3f}".format("Output Processing",
                                   class_load - start_det_loop))
    print("{:25s}: {:2.3f}".format("Drawing Boxes", end - draw))
    print("{:25s}: {:2.3f}".format("Average time_per_img",
                                   (end - load_batch) / len(imlist)))
    print("----------------------------------------------------------")

    torch.cuda.empty_cache()


if __name__ == '__main__':
    # Parse command-line options; the whole namespace is handed to the loop.
    args = arg_parse()

    # Detector classes: COCO's 80 categories, names read from data/coco.names.
    # NOTE(review): ``classes`` is not passed to detection_loop, so it is
    # presumably read there as a module-level global. Keep the name.
    num_classes = 80
    classes = load_classes("data/coco.names")

    detection_loop(args, num_classes)