Code Example #1
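Training driver for a Tacotron 2 text-to-mel model: it seeds the RNGs, builds the model and its Adam-based optimizer (or resumes both from a checkpoint), then loops over epochs of training, validation, early-stopping bookkeeping, and periodic checkpointing.

# Assumed imports for this excerpt (the snippet omits them): numpy as np, torch,
# torch.optim as optim, plus the project's own helpers (Tacotron2,
# Tacotron2Optimizer, Text2MelDataset, Text2MelDataLoader, Tacotron2Loss,
# Logger, train, valid, load_checkpoint, save_checkpoint) and its global config.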
def run(args):
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    logdir = args.logdir
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        # model
        model = Tacotron2(config)
        # optimizer
        optimizer = Tacotron2Optimizer(
            optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.l2,
                       betas=(0.9, 0.999),
                       eps=1e-6))

    else:
        start_epoch, epochs_since_improvement, model, optimizer, best_loss = load_checkpoint(
            logdir, checkpoint)

    logger = Logger(config.logdir, config.experiment, 'tacotron2')

    # Move to GPU, if available
    model = model.to(config.device)

    criterion = Tacotron2Loss()

    # Custom dataloaders
    train_dataset = Text2MelDataset(config.train_files, config)
    train_loader = Text2MelDataLoader(train_dataset,
                                      config,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
    valid_dataset = Text2MelDataset(config.valid_files, config)
    valid_loader = Text2MelDataLoader(valid_dataset,
                                      config,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, args.epochs):
        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           criterion=criterion,
                           epoch=epoch,
                           logger=logger)

        lr = optimizer.lr
        print('\nLearning rate: {}'.format(lr))
        step_num = optimizer.step_num
        print('Step num: {}\n'.format(step_num))

        scalar_dict = {'train_epoch_loss': train_loss, 'learning_rate': lr}
        logger.log_epoch('train', epoch, scalar_dict=scalar_dict)

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           criterion=criterion,
                           logger=logger)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: {}\n".format(
                epochs_since_improvement))
        else:
            epochs_since_improvement = 0

        scalar_dict = {'valid_epoch_loss': valid_loss}
        logger.log_epoch('valid', epoch, scalar_dict=scalar_dict)

        # Save checkpoint
        if epoch % args.save_freq == 0:
            save_checkpoint(logdir, epoch, epochs_since_improvement, model,
                            optimizer, best_loss, is_best)
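run() is driven entirely by the attributes of args, so a thin argparse front end is enough to invoke it. A minimal sketch, assuming flag names that simply mirror the attributes the function reads (the real CLI is not part of the excerpt):

import argparse

def parse_args():
    # One flag per attribute that run() reads; names and defaults are assumptions.
    parser = argparse.ArgumentParser(description='Train Tacotron 2')
    parser.add_argument('--seed', type=int, default=7)
    parser.add_argument('--logdir', type=str, default='logs')
    parser.add_argument('--checkpoint', type=str, default=None)  # resume from this file if given
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--l2', type=float, default=1e-6)        # Adam weight decay
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--save-freq', type=int, default=1)
    return parser.parse_args()

if __name__ == '__main__':
    run(parse_args())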
Code Example #2
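Inference-speed benchmarking for a network: 'simple' mode times forward passes on a fixed dummy tensor, while 'real' mode restores a checkpoint and times the model against an actual validation loader; each mode repeats the measurement args.times times and reports the best FPS.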
    if args.mode == 'simple':
        dummy = torch.ones((1, 3, args.height, args.width))
        print(dummy.dtype)
        fps = []
        for _ in range(args.times):
            fps.append(
                speed_evaluate_simple(net=net,
                                      device=device,
                                      dummy=dummy,
                                      num=300))
        print('GPU FPS: {: .2f}'.format(max(fps)))
    elif args.mode == 'real':
        if cfg['test']['checkpoint'] is not None:
            load_checkpoint(net=net,
                            optimizer=None,
                            lr_scheduler=None,
                            filename=cfg['test']['checkpoint'])
        val_loader = init_dataset(cfg['dataset'], cfg['test_augmentations'],
                                  (args.height, args.width))
        fps = []
        gpu_fps = []
        for _ in range(args.times):
            fps_item, gpu_fps_item = speed_evaluate_real(net=net,
                                                         device=device,
                                                         loader=val_loader,
                                                         num=300)
            fps.append(fps_item)
            gpu_fps.append(gpu_fps_item)
        print('Real FPS: {: .2f}'.format(max(fps)))
        print('GPU FPS: {: .2f}'.format(max(gpu_fps)))
    else:
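speed_evaluate_simple and speed_evaluate_real themselves are not part of the excerpt (which breaks off at the final else: branch). As an illustration only, and not the project's actual implementation, a synchronize-based timing helper for the 'simple' path might look like this:

import time
import torch

def speed_evaluate_simple(net, device, dummy, num):
    # Time `num` forward passes on a fixed input and return frames per second.
    net.eval()
    dummy = dummy.to(device)
    with torch.no_grad():
        for _ in range(10):                      # warm-up iterations
            net(dummy)
        if device.type == 'cuda':
            torch.cuda.synchronize(device)       # don't time queued kernels
        start = time.perf_counter()
        for _ in range(num):
            net(dummy)
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - start
    return num * dummy.shape[0] / elapsed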
Code Example #3
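Tracing/export setup: the model config is rebuilt with a tracing argument, the traced and untraced networks are both moved to a single GPU (CPU as fallback), the same checkpoint is loaded into each (strict=False for the traced variant), and a fixed-seed dummy input is prepared so the two versions' outputs can be compared for numerical precision.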
    cfg_with_trace_arg = append_trace_arg(cfg['model'].copy(), trace_arg)
    net = MODELS.from_dict(cfg_with_trace_arg)

    # Move to device (simple single card)
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    print(device)
    net.to(device)
    net_without_tracing.to(device)

    # Load weights
    if cfg['test']['checkpoint'] is not None:
        load_checkpoint(net=net,
                        optimizer=None,
                        lr_scheduler=None,
                        filename=cfg['test']['checkpoint'],
                        strict=False)
        load_checkpoint(net=net_without_tracing,
                        optimizer=None,
                        lr_scheduler=None,
                        filename=cfg['test']['checkpoint'])
    else:
        raise ValueError('Must provide a weight file by --checkpoint')

    # Set dummy for precision matching
    torch.manual_seed(7)
    dummy = torch.randn(1,
                        3,
                        args.height,
                        args.width,
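The excerpt ends mid-call to torch.randn. Once the dummy input exists, the "precision matching" the comment announces is typically an element-wise comparison of the two networks' outputs; a minimal sketch, assuming both nets return a single tensor (the actual comparison code is not shown):

dummy = dummy.to(device)
net.eval()
net_without_tracing.eval()
with torch.no_grad():
    out_traced = net(dummy)
    out_reference = net_without_tracing(dummy)
# Largest element-wise deviation between the two forward passes.
max_diff = (out_traced - out_reference).abs().max().item()
print('max abs diff: {:.3e}  allclose: {}'.format(
    max_diff, torch.allclose(out_traced, out_reference, atol=1e-5)))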
Code Example #4
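A test/train branch from a semi-supervised segmentation script: the test path builds a loader over labeled and pseudo-labeled batches, restores weights (honoring the mixed-precision flag), and evaluates one set; the training path instead constructs a DynamicMutualLoss and a TensorBoard SummaryWriter.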
    test_loader = init(batch_size_labeled=args.batch_size_labeled,
                       batch_size_pseudo=args.batch_size_pseudo,
                       state=3,
                       split=None,
                       valtiny=args.valtiny,
                       no_aug=args.no_aug,
                       input_sizes=input_sizes,
                       data_set=args.dataset,
                       sets_id=args.sets_id,
                       mean=mean,
                       std=std,
                       keep_scale=keep_scale,
                       reverse_channels=reverse_channels)
    load_checkpoint(net=net,
                    optimizer=None,
                    lr_scheduler=None,
                    is_mixed_precision=args.mixed_precision,
                    filename=args.continue_from)
    test_one_set(loader=test_loader,
                 device=device,
                 net=net,
                 categories=categories,
                 num_classes=num_classes,
                 output_size=input_sizes[2],
                 is_mixed_precision=args.mixed_precision)
else:
    x = 0
    criterion = DynamicMutualLoss(gamma1=args.gamma1,
                                  gamma2=args.gamma2,
                                  ignore_index=255)
    writer = SummaryWriter('logs/' + exp_name)
Code Example #5
File: train.py  Project: ptran1203/pytorch-animeGAN
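End-to-end AnimeGAN training loop: the generator is first warmed up with a VGG content loss for init_epochs, after which generator and discriminator are trained adversarially (optionally adding Gaussian noise to the discriminator's inputs), with periodic checkpoints and sample dumps.

# Assumed imports for this excerpt (omitted from the snippet): torch.optim as optim,
# torch.utils.data.DataLoader, multiprocessing.cpu_count, tqdm.tqdm, plus the
# project's own modules (Generator, Discriminator, AnimeGanLoss, AnimeDataSet,
# LossSummary, load_checkpoint, save_checkpoint, save_samples, set_lr, ...).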
def main(args):
    check_params(args)

    print("Init models...")

    G = Generator(args.dataset).cuda()
    D = Discriminator(args).cuda()

    loss_tracker = LossSummary()

    loss_fn = AnimeGanLoss(args)

    # Create DataLoader
    data_loader = DataLoader(
        AnimeDataSet(args),
        batch_size=args.batch_size,
        num_workers=cpu_count(),
        pin_memory=True,
        shuffle=True,
        collate_fn=collate_fn,
    )

    optimizer_g = optim.Adam(G.parameters(), lr=args.lr_g, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(D.parameters(), lr=args.lr_d, betas=(0.5, 0.999))

    start_e = 0
    if args.resume == 'GD':
        # Load G and D
        try:
            start_e = load_checkpoint(G, args.checkpoint_dir)
            print("G weight loaded")
            load_checkpoint(D, args.checkpoint_dir)
            print("D weight loaded")
        except Exception as e:
            print('Could not load checkpoint, train from scratch', e)
    elif args.resume == 'G':
        # Load G only
        try:
            start_e = load_checkpoint(G, args.checkpoint_dir, posfix='_init')
        except Exception as e:
            print('Could not load G init checkpoint, train from scratch', e)

    for e in range(start_e, args.epochs):
        print(f"Epoch {e}/{args.epochs}")
        bar = tqdm(data_loader)
        G.train()

        init_losses = []

        if e < args.init_epochs:
            # Train with content loss only
            set_lr(optimizer_g, args.init_lr)
            for img, *_ in bar:
                img = img.cuda()

                optimizer_g.zero_grad()

                fake_img = G(img)
                loss = loss_fn.content_loss_vgg(img, fake_img)
                loss.backward()
                optimizer_g.step()

                init_losses.append(loss.cpu().detach().numpy())
                avg_content_loss = sum(init_losses) / len(init_losses)
                bar.set_description(
                    f'[Init Training G] content loss: {avg_content_loss:.2f}')

            set_lr(optimizer_g, args.lr_g)
            save_checkpoint(G, optimizer_g, e, args, posfix='_init')
            save_samples(G, data_loader, args, subname='initg')
            continue

        loss_tracker.reset()
        for img, anime, anime_gray, anime_smt_gray in bar:
            # To cuda
            img = img.cuda()
            anime = anime.cuda()
            anime_gray = anime_gray.cuda()
            anime_smt_gray = anime_smt_gray.cuda()

            # ---------------- TRAIN D ---------------- #
            optimizer_d.zero_grad()
            fake_img = G(img).detach()

            # Add some Gaussian noise to images before feeding to D
            if args.d_noise:
                fake_img += gaussian_noise()
                anime += gaussian_noise()
                anime_gray += gaussian_noise()
                anime_smt_gray += gaussian_noise()

            fake_d = D(fake_img)
            real_anime_d = D(anime)
            real_anime_gray_d = D(anime_gray)
            real_anime_smt_gray_d = D(anime_smt_gray)

            loss_d = loss_fn.compute_loss_D(fake_d, real_anime_d,
                                            real_anime_gray_d,
                                            real_anime_smt_gray_d)

            loss_d.backward()
            optimizer_d.step()

            loss_tracker.update_loss_D(loss_d)

            # ---------------- TRAIN G ---------------- #
            optimizer_g.zero_grad()

            fake_img = G(img)
            fake_d = D(fake_img)

            adv_loss, con_loss, gra_loss, col_loss = loss_fn.compute_loss_G(
                fake_img, img, fake_d, anime_gray)

            loss_g = adv_loss + con_loss + gra_loss + col_loss

            loss_g.backward()
            optimizer_g.step()

            loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)

            avg_adv, avg_gram, avg_color, avg_content = loss_tracker.avg_loss_G()
            avg_adv_d = loss_tracker.avg_loss_D()
            bar.set_description(
                f'loss G: adv {avg_adv:.2f} con {avg_content:.2f} '
                f'gram {avg_gram:.2f} color {avg_color:.2f} / '
                f'loss D: {avg_adv_d:.2f}')

        if e % args.save_interval == 0:
            save_checkpoint(G, optimizer_g, e, args)
            save_checkpoint(D, optimizer_d, e, args)
            save_samples(G, data_loader, args)
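Like run() in the first example, main() consumes only attributes of args; a minimal argparse sketch covering the attributes the loop reads (flag names and defaults are assumptions, not the repository's actual CLI):

import argparse

def parse_args():
    # One flag per attribute the training loop reads; names/defaults are assumptions.
    parser = argparse.ArgumentParser(description='Train AnimeGAN')
    parser.add_argument('--dataset', type=str, default='Hayao')
    parser.add_argument('--batch-size', type=int, default=6)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--init-epochs', type=int, default=5)    # content-loss-only warm-up
    parser.add_argument('--init-lr', type=float, default=1e-4)
    parser.add_argument('--lr-g', type=float, default=2e-4)
    parser.add_argument('--lr-d', type=float, default=4e-4)
    parser.add_argument('--resume', type=str, default='', choices=['', 'G', 'GD'])
    parser.add_argument('--checkpoint-dir', type=str, default='checkpoints')
    parser.add_argument('--save-interval', type=int, default=1)
    parser.add_argument('--d-noise', action='store_true')        # Gaussian noise on D inputs
    return parser.parse_args()

if __name__ == '__main__':
    main(parse_args())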