Example #1
def train(args, pt_dir, chkpt_path, writer, logger, hp, hp_str):
    # If the train-all flag is set, train every tier one by one, starting at the
    # tier given by args.tier and working down; otherwise, train a single tier.
    if args.train_all:
        num_tiers = hp['model']['tier']
        cur_tier = args.tier
        main_name = args.name
        i = 1
        while True:
            print("Training tier #%d" % cur_tier)
            if cur_tier == 1:
                args.tts = True
            else:
                args.tts = False
            args.tier = cur_tier
            print("Beginning training tier %d, for %s with tts=%s. This is iteration #%d." % (args.tier, args.name, str(args.tts), i))
            trainloader = create_dataloader(hp, args, train=True)
            testloader = create_dataloader(hp, args, train=False)
            train_helper(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp, hp_str)
            cur_tier -= 1
            if cur_tier == 0:
                print("All tiers were trained an epoch! Starting again at top tier.")
                cur_tier = num_tiers
            print('')
            i += 1
    else:
        print("Training a specific tier")
        trainloader = create_dataloader(hp, args, train=True)
        testloader = create_dataloader(hp, args, train=False)
        train_helper(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp, hp_str)
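
The `train` function above reads several attributes off `args` (`train_all`, `tier`, `name`, and `tts`, which it sets itself). As a rough sketch, such an argument object could be built with `argparse` as below; the flag names are inferred from the attribute accesses in the example and are assumptions, not this repository's actual command line.

import argparse

# Hypothetical CLI sketch: the flag names below are inferred from the args.*
# attributes used in Example #1, not taken from the repository itself.
parser = argparse.ArgumentParser(description='MelNet training (sketch)')
parser.add_argument('-n', '--name', required=True,
                    help='run name used for logs and checkpoints')
parser.add_argument('-t', '--tier', type=int, default=1,
                    help='tier to train, or the starting tier with --train-all')
parser.add_argument('--train-all', dest='train_all', action='store_true',
                    help='cycle through all tiers, starting at --tier and working down')
args = parser.parse_args()
args.tts = (args.tier == 1)  # train() overwrites this per tier when --train-all is set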
Example #2
    os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(
                                os.path.join(
                                    log_dir,
                                    '%s-%d.log' % (args.name, time.time()))),
                            logging.StreamHandler()
                        ])
    logger = logging.getLogger()

    writer = MyWriter(hp, log_dir)

    assert hp.data.path != '', \
        'hp.data.path cannot be empty: please fill in your dataset\'s path in the configuration YAML file.'

    if args.bucket:
        preload_dataset(args.bucket, hp.data.path, args.sample_threshold)

    trainloader = create_dataloader(hp, args, train=True)
    testloader = create_dataloader(hp, args, train=False)

    experiment = Experiment(api_key=args.comet_key, project_name='MelNet')
    experiment.log_parameters(hp)

    train(args, pt_dir, args.checkpoint_path, trainloader, testloader, writer,
          logger, hp, hp_str, experiment)
Example #3
def get_testloader(hp, args):
    return create_dataloader(hp, args, train=False)
Example #4
def get_testloader(hp, args):
    testloader = create_dataloader(hp, args, train=False)
    return testloader
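
All four examples call `create_dataloader(hp, args, train=...)` without showing its implementation. The following is a minimal sketch of a function with that call signature, assuming a PyTorch `Dataset` that lists audio files under `hp.data.path`; `AudioDataset`, `hp.train.batch_size`, and `hp.train.num_workers` are placeholders, not the repository's real names.

import os
import glob

from torch.utils.data import Dataset, DataLoader

class AudioDataset(Dataset):
    """Placeholder dataset that lists .wav files under hp.data.path."""
    def __init__(self, hp, args, train):
        self.files = sorted(glob.glob(
            os.path.join(hp.data.path, '**', '*.wav'), recursive=True))
        self.train = train

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # A real implementation would load the audio here and compute the
        # mel spectrogram tensor expected by the model.
        return self.files[idx]

def create_dataloader(hp, args, train):
    dataset = AudioDataset(hp, args, train)
    return DataLoader(dataset,
                      batch_size=hp.train.batch_size,    # assumed config key
                      shuffle=train,                      # shuffle only the training set
                      num_workers=hp.train.num_workers)   # assumed config key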