Example #1
def get_pretrained_models(args_speaker, args_speech):
    args_all = {"speaker": args_speaker, "speech": args_speech}
    models = {}
    for key, args in args_all.items():
        CNN_arch = get_dict_from_args([
            'cnn_input_dim', 'cnn_N_filt', 'cnn_len_filt', 'cnn_max_pool_len',
            'cnn_use_laynorm_inp', 'cnn_use_batchnorm_inp', 'cnn_use_laynorm',
            'cnn_use_batchnorm', 'cnn_act', 'cnn_drop'
        ], args.cnn)

        DNN_arch = get_dict_from_args([
            'fc_input_dim', 'fc_lay', 'fc_drop', 'fc_use_batchnorm',
            'fc_use_laynorm', 'fc_use_laynorm_inp', 'fc_use_batchnorm_inp',
            'fc_act'
        ], args.dnn)

        Classifier = get_dict_from_args([
            'fc_input_dim', 'fc_lay', 'fc_drop', 'fc_use_batchnorm',
            'fc_use_laynorm', 'fc_use_laynorm_inp', 'fc_use_batchnorm_inp',
            'fc_act'
        ], args.classifier)

        CNN_arch['fs'] = args.windowing.fs
        model = SincClassifier(CNN_arch, DNN_arch, Classifier)
        if args.model_path != 'none':
            print("load model from:", args.model_path)
            if os.path.splitext(args.model_path)[1] == '.pkl':
                checkpoint_load = torch.load(args.model_path)
                model.load_raw_state_dict(checkpoint_load)
            else:
                load_checkpoint(model, args.model_path, strict=True)

        model = model.cuda().eval()
        # freeze the model
        for p in model.parameters():
            p.requires_grad = False
        models[key] = model

    return models
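
Example #1 relies on a get_dict_from_args helper that is not shown on this page. A minimal sketch of what it presumably does, assuming each config section (args.cnn, args.dnn, args.classifier) behaves like a namespace:

# Hypothetical sketch of the get_dict_from_args helper assumed above:
# collect the listed attribute names from a config section into a plain dict.
def get_dict_from_args(keys, args_section):
    return {key: getattr(args_section, key) for key in keys}
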
Example #2
def main():
    ############################
    # argument setup
    ############################
    args, cfg = setup_args_and_config()

    if args.show:
        print("### Run Argv:\n> {}".format(' '.join(sys.argv)))
        print("### Run Arguments:")
        s = dump_args(args)
        print(s + '\n')
        print("### Configs:")
        print(cfg.dumps())
        sys.exit()

    timestamp = utils.timestamp()
    unique_name = "{}_{}".format(timestamp, args.name)
    cfg['unique_name'] = unique_name  # for save directory
    cfg['name'] = args.name

    utils.makedirs('logs')
    utils.makedirs(Path('checkpoints', unique_name))

    # logger
    logger_path = Path('logs', f"{unique_name}.log")
    logger = Logger.get(file_path=logger_path,
                        level=args.log_lv,
                        colorize=True)

    # writer
    image_scale = 0.6
    writer_path = Path('runs', unique_name)
    if args.tb_image:
        writer = utils.TBWriter(writer_path, scale=image_scale)
    else:
        image_path = Path('images', unique_name)
        writer = utils.TBDiskWriter(writer_path, image_path, scale=image_scale)

    # log default information
    args_str = dump_args(args)
    logger.info("Run Argv:\n> {}".format(' '.join(sys.argv)))
    logger.info("Args:\n{}".format(args_str))
    logger.info("Configs:\n{}".format(cfg.dumps()))
    logger.info("Unique name: {}".format(unique_name))

    # seed
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    random.seed(cfg['seed'])

    if args.deterministic:
        #  https://discuss.pytorch.org/t/how-to-get-deterministic-behavior/18177/16
        #  https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        cfg['n_workers'] = 0
        logger.info("#" * 80)
        logger.info("# Deterministic option is activated !")
        logger.info("#" * 80)
    else:
        torch.backends.cudnn.benchmark = True

    ############################
    # setup dataset & loader
    ############################
    logger.info("Get dataset ...")

    # setup language dependent values
    content_font, n_comp_types, n_comps = setup_language_dependent(cfg)

    # setup transform
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # setup data
    hdf5_data, meta = setup_data(cfg, transform)

    # setup dataset
    trn_dset, loader = get_dset_loader(hdf5_data,
                                       meta['train']['fonts'],
                                       meta['train']['chars'],
                                       transform,
                                       True,
                                       cfg,
                                       content_font=content_font)

    logger.info("### Training dataset ###")
    logger.info("# of avail fonts = {}".format(trn_dset.n_fonts))
    logger.info(f"Total {len(loader)} iterations per epochs")
    logger.info("# of avail items = {}".format(trn_dset.n_avails))
    logger.info(f"#fonts = {trn_dset.n_fonts}, #chars = {trn_dset.n_chars}")

    val_loaders = setup_cv_dset_loader(hdf5_data, meta, transform,
                                       n_comp_types, content_font, cfg)
    sfuc_loader = val_loaders['SeenFonts-UnseenChars']
    sfuc_dset = sfuc_loader.dataset
    ufsc_loader = val_loaders['UnseenFonts-SeenChars']
    ufsc_dset = ufsc_loader.dataset
    ufuc_loader = val_loaders['UnseenFonts-UnseenChars']
    ufuc_dset = ufuc_loader.dataset

    logger.info("### Cross-validation datasets ###")
    logger.info("Seen fonts, Unseen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(sfuc_dset), len(sfuc_dset.fonts), len(sfuc_dset.chars),
                    len(sfuc_loader)))
    logger.info("Unseen fonts, Seen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(ufsc_dset), len(ufsc_dset.fonts), len(ufsc_dset.chars),
                    len(ufsc_loader)))
    logger.info("Unseen fonts, Unseen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(ufuc_dset), len(ufuc_dset.fonts), len(ufuc_dset.chars),
                    len(ufuc_loader)))

    ############################
    # build model
    ############################
    logger.info("Build model ...")
    # generator
    g_kwargs = cfg.get('g_args', {})
    gen = MACore(1,
                 cfg['C'],
                 1,
                 **g_kwargs,
                 n_comps=n_comps,
                 n_comp_types=n_comp_types,
                 language=cfg['language'])
    gen.cuda()
    gen.apply(weights_init(cfg['init']))

    d_kwargs = cfg.get('d_args', {})
    disc = Discriminator(cfg['C'], trn_dset.n_fonts, trn_dset.n_chars,
                         **d_kwargs)
    disc.cuda()
    disc.apply(weights_init(cfg['init']))

    if cfg['ac_w'] > 0.:
        C = gen.mem_shape[0]
        aux_clf = AuxClassifier(C, n_comps, **cfg['ac_args'])
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg['init']))
    else:
        aux_clf = None
        assert cfg['ac_gen_w'] == 0., "ac_gen loss is only available with ac loss"

    # setup optimizer
    g_optim = optim.Adam(gen.parameters(),
                         lr=cfg['g_lr'],
                         betas=cfg['adam_betas'])
    d_optim = optim.Adam(disc.parameters(),
                         lr=cfg['d_lr'],
                         betas=cfg['adam_betas'])
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg['g_lr'], betas=cfg['adam_betas']) \
               if aux_clf is not None else None

    # resume checkpoint
    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim)
        logger.info(
            "Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
                args.resume, st_step - 1, loss))
    if args.finetune:
        load_gen_checkpoint(args.finetune, gen)

    ############################
    # setup validation
    ############################
    evaluator = Evaluator(hdf5_data,
                          trn_dset.avails,
                          logger,
                          writer,
                          cfg['batch_size'],
                          content_font=content_font,
                          transform=transform,
                          language=cfg['language'],
                          val_loaders=val_loaders,
                          meta=meta)
    if args.debug:
        evaluator.n_cv_batches = 10
        logger.info("Change CV batches to 10 for debugging")

    ############################
    # start training
    ############################
    trainer = Trainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim, writer,
                      logger, evaluator, cfg)
    trainer.train(loader, st_step)
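
Examples #2, #8, and #12 format the parsed arguments with a dump_args helper before logging them. A minimal sketch of such a helper, assuming args is an argparse-style namespace:

# Hypothetical sketch of a dump_args helper consistent with the calls above:
# render a namespace as one "key: value" line per argument, sorted by name.
def dump_args(args):
    return "\n".join("{}: {}".format(k, v) for k, v in sorted(vars(args).items()))
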
Example #3
    batch_jornet = convert_data_to_batch(data_crop)
    output_jornet = jornet(batch_jornet)
    jornet_joints_mainout = output_jornet[7][0].data.cpu().numpy()
    # plot depth
    jornet_joints_mainout *= 1.1
    jornet_joints_global = get_jornet_global_depth(jornet_joints_mainout, handroot)
    plot_jornet_joints_global_depth(jornet_joints_global, args.input_img_namebase, gt_joints=labels_jointspace)
    joints_colorspace = joints_globaldepth_to_colorspace(jornet_joints_global, handroot, img_res=(640, 480))
    plot_jornet_colorspace(joints_colorspace, args.input_img_namebase)
    return output_halnet, output_jornet, jornet_joints_global


# load nets
print('Loading HALNet from: ' + args.halnet_filepath)
halnet, _, _, _ = trainer.load_checkpoint(filename=args.halnet_filepath,
                                          model_class=HALNet.HALNet,
                                          use_cuda=args.use_cuda)
print('Loading JORNet from: ' + args.jornet_filepath)
jornet, _, _, _ = trainer.load_checkpoint(filename=args.jornet_filepath,
                                          model_class=JORNet.JORNet,
                                          use_cuda=args.use_cuda)


if args.input_img_namebase == '':
    predict_from_dataset(args, halnet, jornet)
elif args.dataset_folder == '':
    raise Exception('You need to define either a dataset folder (-r) or an image file name base (-i)')
else:
    for i in range(10):
        args.input_img_namebase = args.input_img_namebase[0:-1] + str(i)
        predict_from_image(args, halnet, jornet)
Example #4
def main(args):
    speaker_cfg = args.speaker_cfg
    speech_cfg = args.speech_cfg
    args_speaker = read_conf(speaker_cfg, deepcopy(args))
    args_speaker.model_path = args.speaker_model
    args_speech = read_conf(speech_cfg, deepcopy(args))
    args_speech.model_path = args.speech_model

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("set seed: ", args_speaker.optimization.seed)
    torch.manual_seed(args_speaker.optimization.seed)
    np.random.seed(args_speaker.optimization.seed)
    random.seed(args_speaker.optimization.seed)

    torch.cuda.set_device(args.local_rank)
    if not args.no_dist:
        torch.distributed.init_process_group(backend="nccl")

    train_dataset = TIMIT_speaker(args.data_root,
                                  train=True,
                                  phoneme=True,
                                  norm_factor=True)
    test_dataset = TIMIT_speaker(args.data_root,
                                 train=False,
                                 phoneme=True,
                                 norm_factor=True)

    pretrained_models = get_pretrained_models(args_speaker, args_speech)

    loss_factors = {
        "speaker": args.speaker_factor,
        "speech": args.speech_factor,
        "norm": args.norm_factor
    }
    if args.target < 0:  # non-targeted
        speaker_loss = SpeakerLoss(pretrained_models['speaker'])
    else:  # targeted attack
        speaker_loss = SpeakerLossTarget(pretrained_models['speaker'],
                                         args.target)
    loss_all = {}
    loss_all['speech'] = {
        'model': pretrained_models['speech'],
        'factor': loss_factors['speech'],
        'loss_func': SpeechLoss(pretrained_models['speech'],
                                factor_kld=args.speech_kld_factor)
    }
    loss_all['speaker'] = {
        'model': pretrained_models['speaker'],
        'factor': loss_factors['speaker'],
        'loss_func': speaker_loss
    }
    loss_all['norm'] = {
        'loss_func': MSEWithThreshold(args.norm_clip),
        'factor': loss_factors['norm']
    }

    cost = AdversarialLoss(loss_all)

    model = SpeechTransformer(args.channel, args.kernel_size, args.dilation,
                              args.sample, args.noise_scale)

    if args.pt_file != 'none':
        print("load model from:", args.pt_file)
        if os.path.splitext(args.pt_file)[1] == '.pkl':
            checkpoint_load = torch.load(args.pt_file)
            model.load_raw_state_dict(checkpoint_load)
        else:
            load_checkpoint(model, args.pt_file)

    model = model.cuda()
    if args.eval:
        assert args.pt_file != 'none', "no pretrained model is provided!"
        print('only eval the model')
        evaluate(model, test_dataset, cost)
        return
    if args.test:
        assert args.pt_file != 'none', "no pretrained model is provided!"
        print("only test the model")
        filename_list = open("./data/TIMIT/speaker/test.scp", 'r').readlines()
        filename_list = [_f.strip() for _f in filename_list]
        label_dict = np.load(
            os.path.join(args.data_root, "processed",
                         "TIMIT_labels.npy")).item()
        test_wav(model, filename_list, args.data_root,
                 os.path.join(args.data_root, "output"),
                 pretrained_models['speaker'], label_dict, args.target)
        return
    if args.cpu_test:
        assert args.pt_file != 'none', "no pretrained model is provided!"
        print("only cpu test the model")
        filename_list = open("./data/TIMIT/speaker/test.scp", 'r').readlines()
        filename_list = [_f.strip() for _f in filename_list]
        label_dict = np.load(
            os.path.join(args.data_root, "processed",
                         "TIMIT_labels.npy")).item()
        test_wav_cpu(model, filename_list, args.data_root,
                     os.path.join(args.data_root, "output"),
                     pretrained_models['speaker'], label_dict, args.target)
        return

    print("train the model")
    batch_process = batch_process_speaker
    eval_hook = EvalHook()
    optimizer = optim.Adam(model.parameters(),
                           lr=args_speaker.optimization.lr,
                           betas=(0.95, 0.999))
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2, 0.5)
    if args.no_dist:
        kwarg = {
            'shuffle': True,
            'worker_init_fn': partial(_init_fn, seed=args_speaker.optimization.seed)
        }
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        kwarg = {
            'sampler': train_sampler,
            'worker_init_fn': partial(_init_fn, seed=args_speaker.optimization.seed)
        }
    train_dataloader = DataLoader(train_dataset,
                                  args_speaker.optimization.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  **kwarg)
    test_dataloader = DataLoader(test_dataset,
                                 args_speaker.optimization.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 pin_memory=True)
    trainer = ClassifierTrainer(
        model,
        train_dataloader,
        optimizer,
        cost,
        batch_process,
        args.output_dir,
        0,
        test_dataloader,
        eval_hook=eval_hook,
        eval_every=args_speaker.optimization.N_eval_epoch,
        print_every=args_speaker.optimization.print_every,
        lr_scheduler=lr_scheduler)
    trainer.logger.info(args)
    trainer.run(args_speaker.optimization.N_epochs)
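
Example #4 passes partial(_init_fn, seed=...) as worker_init_fn so each DataLoader worker is seeded reproducibly; the helper itself is not shown here. A minimal sketch under that assumption:

import random

import numpy as np

# Hypothetical sketch of the _init_fn worker-seeding helper assumed above:
# give every DataLoader worker process its own deterministic seed.
def _init_fn(worker_id, seed=0):
    np.random.seed(seed + worker_id)
    random.seed(seed + worker_id)
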
Example #5
dataset_root_folder = 'C:/Users/Administrator/Documents/Datasets/fpa_benchmark/'
gt_folder = 'Hand_pose_annotation_v1'
data_folder = 'video_files'
subject = 'Subject_1'
actions = ['charge_cell_phone',
           'clean_glasses',
           'close_juice_bottle',
           'close_liquid_soap',
           'close_milk',
           'close_peanut_butter',
           'drink_mug',
           'flip_pages']
sequence = '1'

model, _, _, _ = trainer.load_checkpoint(args.checkpoint_filename, HALNet, use_cuda=True)
if args.use_cuda:
    model.cuda()

fig = None
for action in actions:
    for seq in range(3):
        visualize.close_fig(fig)
        fig = visualize.create_fig()
        seq_str = str(seq+1)
        curr_data_folder = '/'.join([dataset_root_folder, data_folder, subject, action, seq_str])
        depth_imgs = []
        for i in range(99):
            if i < 10:
                frame_num = '000' + str(i)
            else:
                frame_num = '00' + str(i)

        'best_loss': 1e10,
        'losses': [],
        'output_filepath': 'log.txt',
        'tot_epoch': args.num_epochs,
        'split_filename': args.split_filename,
        'fpa_subj_split': args.fpa_subj_split,
        'fpa_obj_split': args.fpa_obj_split,
        'dataset_root_folder': args.dataset_root_folder,
        'epoch': 1,
        'batch_idx': -1
    }
    continue_to_batch = False
else:
    print('Loading model from checkpoint: {}'.format(args.checkpoint_filename))
    model, _, train_vars, _ = trainer.load_checkpoint(args.checkpoint_filename, JORNet_light,
                                             use_cuda=True,
                                             fpa_subj=args.fpa_subj_split,
                                             num_channels=2)
    if args.use_cuda:
        model.cuda()
    train_vars['split_filename'] = args.split_filename
    train_vars['fpa_subj_split'] = args.fpa_subj_split
    train_vars['fpa_obj_split'] = args.fpa_obj_split
    train_vars['dataset_root_folder'] = args.dataset_root_folder
    continue_to_batch = True

model.train()

train_loader = fpa_dataset.DataLoaderPoseRegressionFromVQVAE(root_folder=train_vars['dataset_root_folder'],
                                              type='train',
                                              input_type="rgbd",
                                              transform_color=transform_color,

    root_folder=args.dataset_root_folder,
    type='test',
    input_type="depth",
    transform_color=transform_color,
    transform_depth=transform_depth,
    batch_size=args.batch_size,
    split_filename=args.split_filename,
    fpa_subj_split=args.fpa_subj_split)

print('Length of dataset: {}'.format(len(test_loader.dataset)))

model_params_dict = {'joint_ixs': range(2)}

print('Loading model from checkpoint: {}'.format(args.checkpoint_filename))
model, _, _, _ = trainer.load_checkpoint(args.checkpoint_filename,
                                         JORNet_light,
                                         use_cuda=True,
                                         fpa_subj=args.fpa_subj_split)
if args.use_cuda:
    model.cuda()

loss_func = my_losses.cross_entropy_loss_p_logq
losses = []
for i in range(len(test_loader.dataset)):
    losses.append([])

train_vars = {
    'iter_size': 1,
    'total_loss': 0,
    'verbose': True,
    'checkpoint_filenamebase': 'checkpoint_test_fpa_subj',
    'checkpoint_filename': 'checkpoint_test_fpa_subj.pth.tar',
Example #8
File: train.py Project: yqGANs/lffont
def train(args, cfg, ddp_gpu=-1):
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "logs" / "{}.log".format(cfg.unique_name)
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.6
    writer_path = cfg.work_dir / "runs" / cfg.unique_name
    image_path = cfg.work_dir / "images" / cfg.unique_name
    writer = utils.TBDiskWriter(writer_path, image_path, scale=image_scale)

    args_str = dump_args(args)
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))
        logger.info("Unique name: {}".format(cfg.unique_name))

    logger.info("Get dataset ...")

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)

    trn_transform, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)
    env_get = lambda env, x, y, transform: transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    data_meta = load_json(cfg.data_meta)
    dec_dict = load_json(cfg.dec_dict)

    if cfg.phase == "comb":
        get_trn_loader = get_comb_trn_loader
        get_cv_loaders = get_cv_comb_loaders
        Trainer = CombinedTrainer

    elif cfg.phase == "fact":
        get_trn_loader = get_fact_trn_loader
        get_cv_loaders = get_cv_fact_loaders
        Trainer = FactorizeTrainer

    else:
        raise ValueError(cfg.phase)

    trn_dset, trn_loader = get_trn_loader(env,
                                          env_get,
                                          cfg,
                                          data_meta["train"],
                                          dec_dict,
                                          trn_transform,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    if is_main_worker(ddp_gpu):
        cv_loaders = get_cv_loaders(env,
                                    env_get,
                                    cfg,
                                    data_meta,
                                    dec_dict,
                                    val_transform,
                                    num_workers=cfg.n_workers,
                                    shuffle=False)
    else:
        cv_loaders = None

    logger.info("Build model ...")
    # generator
    g_kwargs = cfg.get("g_args", {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg.C, 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    if cfg.gan_w > 0.:
        d_kwargs = cfg.get("d_args", {})
        disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_unis, **d_kwargs)
        disc.cuda()
        disc.apply(weights_init(cfg.init))
    else:
        disc = None

    if cfg.ac_w > 0.:
        aux_clf = aux_clf_builder(gen.mem_shape, n_comps, **cfg.ac_args)
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg.init))
    else:
        aux_clf = None
        assert cfg.ac_gen_w == 0., "ac_gen loss is only available with ac loss"

    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas) \
        if disc is not None else None
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg.ac_lr, betas=cfg.adam_betas) \
        if aux_clf is not None else None

    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf, g_optim, d_optim, ac_optim, cfg.overwrite)
        logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
            args.resume, st_step - 1, loss))
        if cfg.overwrite:
            st_step = 1

    evaluator = Evaluator(env,
                          env_get,
                          logger,
                          writer,
                          cfg.batch_size,
                          val_transform,
                          content_font,
                          use_half=cfg.use_half
                          )

    trainer = Trainer(gen, disc, g_optim, d_optim,
                      aux_clf, ac_optim,
                      writer, logger,
                      evaluator, cv_loaders,
                      cfg)

    trainer.train(trn_loader, st_step, cfg[f"{cfg.phase}_iter"])
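
Example #8 reads training samples through load_lmdb and read_data_from_lmdb, which are not shown on this page. A minimal sketch of such helpers, assuming each record is pickled under a "{font}_{char}" key and carries an 'img' field:

import pickle

import lmdb

# Hypothetical sketch of the LMDB helpers assumed in Example #8:
# open the database read-only and unpickle one record per key.
def load_lmdb(data_path):
    return lmdb.open(str(data_path), readonly=True, lock=False)

def read_data_from_lmdb(env, key):
    with env.begin(write=False) as txn:
        value = txn.get(key.encode("utf-8"))
    return pickle.loads(value)  # e.g. {"img": ...}
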
Example #9
def parse_args(model_class):
    parser = argparse.ArgumentParser(
        description='Train a hand-tracking deep neural network')
    parser.add_argument('--num_iter',
                        dest='num_iter',
                        type=int,
                        help='Total number of iterations to train')
    parser.add_argument('-c',
                        dest='checkpoint_filepath',
                        default='',
                        required=True,
                        help='Checkpoint file from which to begin training')
    parser.add_argument('--log_interval',
                        type=int,
                        dest='log_interval',
                        default=10,
                        help='Interval, in iterations, at which to log'
                             ' a model checkpoint (default 10)')
    parser.add_argument('-v',
                        '--verbose',
                        dest='verbose',
                        action='store_true',
                        default=True,
                        help='Verbose mode')
    parser.add_argument('--max_mem_batch',
                        type=int,
                        dest='max_mem_batch',
                        default=8,
                        help='Max size of batch given GPU memory (default 8)')
    parser.add_argument(
        '--batch_size',
        type=int,
        dest='batch_size',
        default=16,
        help='Batch size for training (if larger than max memory batch, training will take '
        'the required amount of iterations to complete a batch)')
    parser.add_argument('-r',
                        dest='root_folder',
                        default='',
                        required=True,
                        help='Root folder for dataset')
    parser.add_argument('--visual',
                        dest='visual_debugging',
                        action='store_true',
                        default=False,
                        help='Whether to visually inspect results')
    parser.add_argument('--cuda',
                        dest='use_cuda',
                        action='store_true',
                        default=False,
                        help='Whether to use cuda for training')
    parser.add_argument('--split_filename',
                        default='',
                        required=False,
                        help='Split filename for the file with dataset splits')
    args = parser.parse_args()

    control_vars, valid_vars = initialize_vars(args)
    control_vars['visual_debugging'] = args.visual_debugging

    print_verbose(
        "Loading model and optimizer from file: " + args.checkpoint_filepath,
        args.verbose)

    model, optimizer, valid_vars, train_control_vars = \
        trainer.load_checkpoint(filename=args.checkpoint_filepath, model_class=model_class, use_cuda=args.use_cuda)

    valid_vars['root_folder'] = args.root_folder
    valid_vars['use_cuda'] = args.use_cuda
    control_vars['log_interval'] = args.log_interval

    random_int_str = args.checkpoint_filepath.split('_')[-2]
    valid_vars['checkpoint_filenamebase'] = 'valid_halnet_log_' + random_int_str + '_'
    control_vars['output_filepath'] = 'validated_halnet_log_' + random_int_str + '.txt'
    msg = print_verbose(
        "Printing also to output filepath: " + control_vars['output_filepath'],
        args.verbose)
    with open(control_vars['output_filepath'], 'w+') as f:
        f.write(msg + '\n')

    if valid_vars['use_cuda']:
        print_verbose("Using CUDA", args.verbose)
    else:
        print_verbose("Not using CUDA", args.verbose)

    control_vars['num_epochs'] = 100
    control_vars['verbose'] = True

    if valid_vars['cross_entropy']:
        print_verbose("Using cross entropy loss", args.verbose)

    control_vars['num_iter'] = 0

    valid_vars['split_filename'] = args.split_filename

    return model, optimizer, control_vars, valid_vars, train_control_vars
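
Example #9 routes its console output through a print_verbose helper whose return value is also written to the output file (msg = print_verbose(...)). A minimal sketch consistent with that usage:

# Hypothetical sketch of the print_verbose helper assumed in Example #9:
# print only when verbose is set, and return the message so it can be logged elsewhere.
def print_verbose(msg, verbose):
    if verbose:
        print(msg)
    return msg
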
Example #10
def main(_):
    """

    :return:
    """
    print("=" * 100)
    print("FLAGS")
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    # make sub-directories
    if not os.path.isdir(FLAGS.checkpoint_dir):
        os.mkdir(FLAGS.checkpoint_dir)
    if not os.path.isdir(FLAGS.sample_dir):
        os.mkdir(FLAGS.sample_dir)
    if not os.path.isdir(FLAGS.test_dir):
        os.mkdir(FLAGS.test_dir)

    # Launch Graph
    sess = tf.Session()
    model = Pix2Pix(sess=sess,
                    gan_name=FLAGS.gan_name,
                    dataset_name=FLAGS.dataset_name,
                    input_size=FLAGS.input_size,
                    input_dim=FLAGS.input_dim,
                    output_size=FLAGS.output_size,
                    output_dim=FLAGS.output_dim,
                    batch_size=FLAGS.batch_size,
                    gen_num_filter=FLAGS.gen_num_filter,
                    disc_num_filter=FLAGS.disc_num_filter,
                    learning_rate=FLAGS.learning_rate,
                    beta1=FLAGS.beta1,
                    l1_lambda=FLAGS.l1_lambda,
                    checkpoint_dir=FLAGS.checkpoint_dir,
                    sample_dir=FLAGS.sample_dir,
                    test_dir=FLAGS.test_dir)

    sess.run(tf.global_variables_initializer())

    # show all variables
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)

    # load trained model
    flag_checkpoint, counter = load_checkpoint(model)

    dataset_dir = os.path.join("datasets", FLAGS.dataset_name)
    if FLAGS.train:
        # training dataset dir
        trainset_dir = os.path.join(dataset_dir, "train")
        valset_dir = os.path.join(dataset_dir, "val")
        run_train(model=model,
                  trainset_dir=trainset_dir,
                  valset_dir=valset_dir,
                  sample_size=FLAGS.batch_size,
                  scale_size=FLAGS.scale_size,
                  crop_size=FLAGS.crop_size,
                  flip=FLAGS.flip,
                  training_epochs=FLAGS.epoch,
                  flag_checkpoint=flag_checkpoint,
                  checkpoint_counter=counter)

    else:
        # test dir
        testset_dir = os.path.join(dataset_dir, "test")
        if not os.path.isdir(testset_dir):
            testset_dir = os.path.join(dataset_dir, "val")

        run_test(model=model, testset_dir=testset_dir)
Example #11
def main(args):
    args = read_conf(args.cfg, args)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if args.dataset == 'timit':
        train_dataset = TIMIT(data_root=args.data_root,
                              datalist_root=args.datalist_root,
                              train=True,
                              oversampling=args.oversampling)
        test_dataset = TIMIT(data_root=args.data_root,
                             datalist_root=args.datalist_root,
                             train=False)
    elif args.dataset == 'libri':
        raise NotImplementedError
    else:
        raise NotImplementedError

    cost = nn.NLLLoss()

    CNN_arch = {
        'input_dim': train_dataset.wlen,
        'fs': args.fs,
        'cnn_N_filt': args.cnn_N_filt,
        'cnn_len_filt': args.cnn_len_filt,
        'cnn_max_pool_len': args.cnn_max_pool_len,
        'cnn_use_laynorm_inp': args.cnn_use_laynorm_inp,
        'cnn_use_batchnorm_inp': args.cnn_use_batchnorm_inp,
        'cnn_use_laynorm': args.cnn_use_laynorm,
        'cnn_use_batchnorm': args.cnn_use_batchnorm,
        'cnn_act': args.cnn_act,
        'cnn_drop': args.cnn_drop,
    }

    DNN1_arch = {
        'fc_lay': args.fc_lay,
        'fc_drop': args.fc_drop,
        'fc_use_batchnorm': args.fc_use_batchnorm,
        'fc_use_laynorm': args.fc_use_laynorm,
        'fc_use_laynorm_inp': args.fc_use_laynorm_inp,
        'fc_use_batchnorm_inp': args.fc_use_batchnorm_inp,
        'fc_act': args.fc_act,
    }

    DNN2_arch = {
        'input_dim': args.fc_lay[-1],
        'fc_lay': args.class_lay,
        'fc_drop': args.class_drop,
        'fc_use_batchnorm': args.class_use_batchnorm,
        'fc_use_laynorm': args.class_use_laynorm,
        'fc_use_laynorm_inp': args.class_use_laynorm_inp,
        'fc_use_batchnorm_inp': args.class_use_batchnorm_inp,
        'fc_act': args.class_act,
    }

    model = SpeakerIDNet(CNN_arch, DNN1_arch, DNN2_arch)
    if args.pt_file != '':
        print("load model from:", args.pt_file)
        checkpoint_load = torch.load(args.pt_file)
        ext = os.path.splitext(args.pt_file)[1]
        if ext == '.pkl':
            model.load_raw_state_dict(checkpoint_load)
        elif ext == '.pickle':
            model.load_state_dict(checkpoint_load)
        elif ext == '.pth':
            load_checkpoint(model, args.pt_file)
        else:
            raise NotImplementedError
    model = model.cuda()
    if args.eval:
        print('only eval the model')
        evaluate(model, test_dataset, cost)
        return
    else:
        print("train the model")
    optimizer = optim.RMSprop(model.parameters(),
                              lr=args.lr,
                              alpha=0.95,
                              eps=1e-8)
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)
    test_dataloader = DataLoader(test_dataset,
                                 1,
                                 shuffle=False,
                                 num_workers=8,
                                 pin_memory=True)
    trainer = ClassifierTrainer(model,
                                train_dataloader,
                                optimizer,
                                cost,
                                batch_process,
                                args.output_dir,
                                0,
                                test_dataloader,
                                eval_every=args.N_eval_epoch,
                                print_every=args.print_every)
    trainer.run(args.N_epochs)
Example #12
File: train.py Project: coallaoh/mxfont
def train(args, cfg, ddp_gpu=-1):
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "log.log"
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.5
    image_path = cfg.work_dir / "images"
    writer = utils.DiskWriter(image_path, scale=image_scale)
    cfg.tb_freq = -1

    args_str = dump_args(args)
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))

    logger.info("Get dataset ...")

    trn_transform, val_transform = setup_transforms(cfg)

    primals = json.load(open(cfg.primals))
    decomposition = json.load(open(cfg.decomposition))
    n_comps = len(primals)

    trn_dset, trn_loader = get_trn_loader(cfg.dset.train,
                                          primals,
                                          decomposition,
                                          trn_transform,
                                          use_ddp=cfg.use_ddp,
                                          batch_size=cfg.batch_size,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    test_dset, test_loader = get_val_loader(cfg.dset.val,
                                            val_transform,
                                            batch_size=cfg.batch_size,
                                            num_workers=cfg.n_workers,
                                            shuffle=False)

    logger.info("Build model ...")
    # generator
    g_kwargs = cfg.get("g_args", {})
    gen = Generator(1, cfg.C, 1, **g_kwargs)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    d_kwargs = cfg.get("d_args", {})
    disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_chars, **d_kwargs)
    disc.cuda()
    disc.apply(weights_init(cfg.init))

    aux_clf = aux_clf_builder(gen.feat_shape["last"], trn_dset.n_fonts,
                              n_comps, **cfg.ac_args)
    aux_clf.cuda()
    aux_clf.apply(weights_init(cfg.init))

    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas)
    ac_optim = optim.Adam(aux_clf.parameters(),
                          lr=cfg.ac_lr,
                          betas=cfg.adam_betas)

    st_step = 0
    if cfg.resume:
        st_step, loss = load_checkpoint(cfg.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim,
                                        cfg.force_resume)
        logger.info(
            "Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
                cfg.resume, st_step, loss))

    evaluator = Evaluator(writer)

    trainer = FactTrainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim,
                          writer, logger, evaluator, test_loader, cfg)

    trainer.train(trn_loader, st_step, cfg.max_iter)
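
Examples #2, #8, and #12 initialize their networks with gen.apply(weights_init(cfg.init)), i.e. weights_init is a factory that returns a per-module initializer. A minimal sketch of such a factory (the exact init schemes used by these projects may differ):

import torch.nn as nn

# Hypothetical sketch of a weights_init factory compatible with nn.Module.apply():
# return a function that initializes conv/linear layers according to init_type.
def weights_init(init_type="xavier"):
    def initializer(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            if init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight)
            else:
                nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    return initializer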