def load_and_setup_model(model_name, parser, checkpoint, amp_run, cpu_run, forward_is_infer=False):
    """Construct ``model_name``, optionally load weights, and prepare it for inference.

    Model-specific options are parsed from the command line via ``parser``
    (unknown options are ignored), then used to build the model config.

    Args:
        model_name: name understood by the ``models`` helpers; ``"WaveGlow"``
            additionally has weight normalization removed.
        parser: argument parser extended with the model's options.
        checkpoint: path to a checkpoint holding a ``'state_dict'`` entry, or
            ``None`` to keep the freshly constructed weights.
        amp_run: when true, run ``amp.initialize`` with ``opt_level="O3"``.
        cpu_run: when true, the checkpoint is loaded with CPU map_location and
            the flag is forwarded to ``models.get_model``.
        forward_is_infer: forwarded to ``models.get_model``.

    Returns:
        The model in eval mode (AMP-initialized when ``amp_run`` is set).
    """
    known_args, _unused = models.parse_model_args(
        model_name, parser, add_help=False).parse_known_args()

    config = models.get_model_config(model_name, known_args)
    net = models.get_model(model_name, config, cpu_run,
                           forward_is_infer=forward_is_infer)

    if checkpoint is not None:
        # Only CPU runs remap storage; GPU runs use torch.load's default placement.
        load_kwargs = {'map_location': torch.device('cpu')} if cpu_run else {}
        weights = torch.load(checkpoint, **load_kwargs)['state_dict']
        if checkpoint_from_distributed(weights):
            weights = unwrap_distributed(weights)
        net.load_state_dict(weights)

    if model_name == "WaveGlow":
        # remove_weightnorm takes the model as its argument and returns the
        # model to use from here on.
        net = net.remove_weightnorm(net)

    net.eval()

    if amp_run:
        net, _ = amp.initialize(net, [], opt_level="O3")

    return net
Example #2
0
def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath,
                    local_rank, resume_from_multiproc):
    """Resume training state (weights, optimizer, RNG, AMP) from a checkpoint file.

    Requires at least one CUDA device: the per-device RNG index is computed
    as ``local_rank % torch.cuda.device_count()``, which would otherwise
    divide by zero.

    Args:
        model: module that receives ``checkpoint['state_dict']``.
        optimizer: optimizer that receives ``checkpoint['optimizer']``.
        epoch: mutable sequence used as an out-parameter; ``epoch[0]`` is set
            to the stored epoch + 1, i.e. the epoch to resume from.
        config: not read. Kept for backward compatibility. Assigning to a
            parameter cannot reach the caller, so the stored config is
            returned instead.
        amp_run: when true, also restore ``amp`` state from ``checkpoint['amp']``.
        filepath: checkpoint path passed to ``torch.load``.
        local_rank: this process's local rank; selects which saved per-device
            RNG state to restore.
        resume_from_multiproc: when true and ``checkpoint_from_distributed``
            reports a distributed-format state dict, it is passed through
            ``unwrap_distributed`` before loading.

    Returns:
        ``checkpoint['config']``, the configuration stored in the checkpoint.

    Raises:
        KeyError: if any expected checkpoint entry is missing.
    """
    # Map every tensor to CPU regardless of the device it was saved from.
    checkpoint = torch.load(filepath, map_location='cpu')

    epoch[0] = checkpoint['epoch'] + 1

    device_id = local_rank % torch.cuda.device_count()
    torch.cuda.set_rng_state(checkpoint['cuda_rng_state_all'][device_id])
    torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])

    # Rebinding the ``config`` parameter would be invisible to the caller,
    # so keep the stored config and hand it back as the return value.
    saved_config = checkpoint['config']

    state_dict = checkpoint['state_dict']
    if checkpoint_from_distributed(state_dict) and resume_from_multiproc:
        state_dict = unwrap_distributed(state_dict)
    model.load_state_dict(state_dict)
    optimizer.load_state_dict(checkpoint['optimizer'])

    if amp_run:
        amp.load_state_dict(checkpoint['amp'])

    print(f"Loaded checkpoint from epoch {epoch[0]}")
    return saved_config
Example #3
0
def _register_training_metrics():
    """Declare every metric that ``main`` reports through ``LOGGER``, with its scope."""
    LOGGER.register_metric(tags.TRAIN_ITERATION_LOSS,
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("iter_time", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("epoch_time", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("run_time", metric_scope=dllg.RUN_SCOPE)
    LOGGER.register_metric("val_iter_loss", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_loss",
                           metric_scope=dllg.EPOCH_SCOPE)


def _restore_training_checkpoint(checkpoint_path, model, optimizer, amp_run):
    """Load model, optimizer and (when ``amp_run``) AMP state from ``checkpoint_path``.

    Call this on the bare model, before any DDP wrapping. A distributed-format
    state dict is passed through ``unwrap_distributed`` first; presumably that
    strips the wrapper's parameter-name prefix, so the result only matches
    the unwrapped module. TODO confirm against unwrap_distributed.
    """
    # Load onto CPU; load_state_dict copies values into the live parameters
    # and optimizer state is cast to each parameter's device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['state_dict']

    if checkpoint_from_distributed(state_dict):
        state_dict = unwrap_distributed(state_dict)

    if amp_run:
        amp.load_state_dict(checkpoint['amp'])

    optimizer.load_state_dict(checkpoint['optimizer'])
    model.load_state_dict(state_dict)
    print("Loaded from checkpoint: %s !" % checkpoint_path)


def main():
    """Train a model selected by ``--model_name`` from command-line arguments.

    Overall flow:
    1. Parse arguments and configure logging.
    2. Build the model and an Adam optimizer, with AMP O1 if ``amp_run``.
    3. Optionally restore a checkpoint.
    4. Wrap the model in DDP when ``world_size > 1``.
    5. Train for ``args.epochs`` epochs, validating after each one.

    On rank 0, a checkpoint and an audio sample are saved every
    ``args.epochs_per_checkpoint`` epochs.

    Raises:
        Exception: if the (reduced) training loss becomes NaN.
    """
    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    # First, lenient pass: only generic options (model name, rank, log file)
    # are needed before the model-specific options are added below.
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        # Only rank 0 writes the JSON log file.
        dllg.JsonBackend(log_file=args.log_file if args.rank == 0 else None,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])

    LOGGER.timed_block_start("run")
    _register_training_metrics()

    log_hardware()

    model_name = args.model_name
    parser = models.parse_model_args(model_name, parser)
    # Second, strict pass now that the model-specific options are registered.
    args = parser.parse_args()

    log_args(args)

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    distributed_run = args.world_size > 1
    if distributed_run:
        init_distributed(args, args.world_size, args.rank, args.group_name)

    LOGGER.log(key=tags.RUN_START)
    run_start_time = time.time()

    model_config = models.get_model_config(model_name, args)
    model = models.get_model(
        model_name,
        model_config,
        to_cuda=True,
        uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    if args.amp_run:
        # Apex expects amp.initialize to run on the bare model, before DDP wrapping.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.checkpoint != "":
        # Restore before DDP wrapping. In both AMP and non-AMP runs, the
        # unwrapped state dict then matches the module's parameter names.
        _restore_training_checkpoint(args.checkpoint, model, optimizer,
                                     args.amp_run)

    if distributed_run:
        model = DDP(model)

    # Model-specific options may be absent from args.
    sigma = getattr(args, 'sigma', None)
    criterion = loss_functions.get_loss_function(model_name, sigma)

    n_frames_per_step = getattr(args, 'n_frames_per_step', None)
    collate_fn = data_functions.get_collate_function(model_name,
                                                     n_frames_per_step)

    trainset = data_functions.get_data_loader(model_name, args.dataset_path,
                                              args.training_files, args)
    if distributed_run:
        # The sampler partitions (and shuffles) the data across ranks, and
        # DataLoader rejects shuffle=True together with an explicit sampler.
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        # Without a sampler, shuffle here; otherwise every epoch would see
        # the data in the same fixed order.
        train_sampler = None
        shuffle = True
    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    valset = data_functions.get_data_loader(model_name, args.dataset_path,
                                            args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0

    LOGGER.log(key=tags.TRAIN_LOOP)

    for epoch in range(args.epochs):
        LOGGER.epoch_start()
        epoch_start_time = time.time()
        LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)

        if train_sampler is not None:
            # Reseed the distributed shuffle so each epoch gets a new permutation.
            train_sampler.set_epoch(epoch)

        # validate()/save_sample() from the previous epoch may have switched
        # the model to eval mode, so re-enter training mode every epoch.
        model.train()

        # Accumulators for the epoch-level throughput and loss metrics.
        reduced_num_items_epoch = 0
        train_epoch_avg_loss = 0.0
        train_epoch_avg_items_per_sec = 0.0
        num_iters = 0

        for i, batch in enumerate(train_loader):
            print("Batch: {}/{} epoch {}".format(i, len(train_loader), epoch))
            LOGGER.iteration_start()
            iter_start_time = time.time()
            LOGGER.log(key=tags.TRAIN_ITER_START, value=i)

            adjust_learning_rate(epoch, optimizer, args.learning_rate,
                                 args.anneal_steps, args.anneal_factor)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            y_pred = model(x)
            loss = criterion(y_pred, y)

            if distributed_run:
                # Loss is averaged over ranks; the item count is summed.
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            LOGGER.log(key=tags.TRAIN_ITERATION_LOSS, value=reduced_loss)

            train_epoch_avg_loss += reduced_loss
            num_iters += 1
            reduced_num_items_epoch += reduced_num_items

            if args.amp_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                # Clip the FP32 master copies that AMP maintains.
                clip_params = amp.master_params(optimizer)
            else:
                loss.backward()
                clip_params = model.parameters()
            torch.nn.utils.clip_grad_norm_(clip_params, args.grad_clip_thresh)

            optimizer.step()

            iteration += 1

            LOGGER.log(key=tags.TRAIN_ITER_STOP, value=i)

            iter_time = time.time() - iter_start_time
            items_per_sec = reduced_num_items / iter_time
            train_epoch_avg_items_per_sec += items_per_sec

            LOGGER.log(key="train_iter_items/sec", value=items_per_sec)
            LOGGER.log(key="iter_time", value=iter_time)
            LOGGER.iteration_stop()

        LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
        epoch_time = time.time() - epoch_start_time

        LOGGER.log(key="train_epoch_items/sec",
                   value=(reduced_num_items_epoch / epoch_time))
        LOGGER.log(key="train_epoch_avg_items/sec",
                   value=((train_epoch_avg_items_per_sec / num_iters)
                          if num_iters > 0 else 0.0))
        LOGGER.log(key="train_epoch_avg_loss",
                   value=((train_epoch_avg_loss / num_iters)
                          if num_iters > 0 else 0.0))
        LOGGER.log(key="epoch_time", value=epoch_time)

        LOGGER.log(key=tags.EVAL_START, value=epoch)

        validate(model, criterion, valset, iteration, args.batch_size,
                 args.world_size, collate_fn, distributed_run, args.rank,
                 batch_to_gpu)

        LOGGER.log(key=tags.EVAL_STOP, value=epoch)

        if (epoch % args.epochs_per_checkpoint == 0) and args.rank == 0:
            checkpoint_path = os.path.join(
                args.output_directory,
                "checkpoint_{}_{}".format(model_name, epoch))
            save_checkpoint(model, optimizer, epoch, model_config,
                            checkpoint_path, amp if args.amp_run else None)
            save_sample(
                model_name, model, args.waveglow_checkpoint,
                args.tacotron2_checkpoint, args.phrase_path,
                os.path.join(args.output_directory,
                             "sample_{}_{}.wav".format(model_name, iteration)),
                args.sampling_rate)

        LOGGER.epoch_stop()

    run_stop_time = time.time()
    run_time = run_stop_time - run_start_time
    LOGGER.log(key="run_time", value=run_time)
    LOGGER.log(key=tags.RUN_FINAL)

    print("training time", run_time)

    LOGGER.timed_block_stop("run")

    if args.rank == 0:
        LOGGER.finish()