Example #1
def main():
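    # Anchor wind speeds; presumably the class centers used by the 'exp'
    # classification head and by val_epoch below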
    anchors = [30, 54, 95]
    shuffle = not args.no_shuffle
    exp = args.exp
    warm_up_epoch = 3

    # Load and process data

    if args.fold:
        df_train = pd.read_csv(args.data_path +
                               'k_fold/official_train_fold%d.csv' %
                               (args.fold))
        df_val = pd.read_csv(args.data_path +
                             'k_fold/official_val_fold%d.csv' % (args.fold))
    else:
        df_train = pd.read_csv(args.data_path + 'official_train.csv')
        df_val = pd.read_csv(args.data_path + 'official_val.csv')

    train = df_train.image_path.to_list()
    val = df_val.image_path.to_list()
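    # In 'exp' mode the targets are an anchor class plus an exp_wind
    # regression value; otherwise wind_speed is regressed directly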
    if exp:
        y_train = df_train.anchor.to_list()
        y_val = df_val.anchor.to_list()
        reg_train_gt = df_train.exp_wind.to_list()
        reg_val_gt = df_val.exp_wind.to_list()
    else:
        y_train = df_train.wind_speed.to_list()
        y_val = df_val.wind_speed.to_list()

    train_transform, val_transform = get_transform(args.image_size)

    train_dataset = WindDataset(image_list=train,
                                target=y_train,
                                exp_target=reg_train_gt if exp else None,
                                transform=train_transform)

    val_dataset = WindDataset(image_list=val,
                              target=y_val,
                              exp_target=reg_val_gt if exp else None,
                              transform=val_transform)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=shuffle,
                              num_workers=args.num_workers,
                              drop_last=True)

    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            drop_last=True)

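    # Warm-up loader: a much larger batch (14x), which presumably fits in
    # memory because the backbone is frozen during warm-up
    # (see freeze_except_last below)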
    warm_loader = DataLoader(dataset=train_dataset,
                             batch_size=args.batch_size * 14,
                             shuffle=shuffle,
                             num_workers=args.num_workers,
                             drop_last=True)

    # Load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    last_epoch = 0

    # model = ResNet50_BN_idea()
    if not exp:
        model = Effnet_Wind_B7()
        # model = Effnet_Wind_B5()
    else:
        model = Effnet_Wind_B5_exp_6()
    # model = ResNetExample()
    # if not exp:
    #     model = Seresnext_Wind()
    # else:
    #     model = Seresnext_Wind_Exp()

    # Optimizer
    if args.opt == 'radam':
        optimizer = RAdam(
            model.parameters(),
            lr=args.lr,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=args.weight_decay,
        )
    elif args.opt == 'adamw':
        optimizer = AdamW(model.parameters(), args.lr)

    elif args.opt == 'adam':
        optimizer = Adam(model.parameters(),
                         args.lr,
                         weight_decay=args.weight_decay)
    else:
        optimizer = SGD(model.parameters(),
                        args.lr,
                        momentum=0.9,
                        nesterov=True,
                        weight_decay=args.weight_decay)

    if args.weights:
        # model.load_state_dict(torch.load(args.weights))
        last_epoch = extract_number(args.weights)
        try:
            checkpoint = torch.load(args.weights)
            model.load_state_dict(checkpoint['model_state_dict'])
            if checkpoint['pre_opt'] == args.opt:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                print(optimizer)
        except (KeyError, TypeError):
            # Fall back to checkpoints saved as a bare state_dict
            model.load_state_dict(torch.load(args.weights))
    else:
        model.apply(reset_m_batchnorm)

    model.to(device)

    # Loss function
    if exp:
        criterion = JointLoss2()
    else:
        criterion = RMSELoss()

    # generate log and visualization
    save_path = args.save_path

    log_cache = (args.batch_size, args.image_size, shuffle, exp)

    write_log(args.save_path, model, optimizer, criterion, log_cache)

    plot_dict = {'train': list(), 'val': list()}

    log_train_path = save_path + 'training_log.txt'
    plot_train_path = save_path + 'log.json'

    write_mode = 'w'

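    # When resuming, append to the existing log and truncate the plotted
    # history to the checkpointed epoch so re-run epochs are not duplicated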
    if os.path.exists(log_train_path) and os.path.exists(plot_train_path):
        write_mode = 'a'
        with open(plot_train_path, 'r') as j:
            plot_dict = json.load(j)
            plot_dict['train'] = plot_dict['train'][:last_epoch]
            plot_dict['val'] = plot_dict['val'][:last_epoch]

    # Training
    print('Start warm up')
    model.freeze_except_last()
    for _ in range(warm_up_epoch):
        warm_up(
            model=model,
            dataloader=warm_loader,
            optimizer=optimizer,
            criterion=criterion,
            device=device,
        )
    model.unfreeze()
    with open(log_train_path, write_mode) as f:
        for epoch in range(1, args.epoch + 1):
            print('Epoch:', epoch + last_epoch)
            f.write('Epoch: %d\n' % (epoch + last_epoch))
            loss = train_epoch(model=model,
                               dataloader=train_loader,
                               optimizer=optimizer,
                               criterion=criterion,
                               device=device,
                               exp=exp)
            RMSE = val_epoch(model=model,
                             dataloader=val_loader,
                             device=device,
                             exp=exp,
                             anchors=anchors)
            if not exp:
                f.write('Training loss: %.4f\n' % (loss))
                f.write('RMSE val: %.4f\n' % (RMSE))
                print('Training loss: %.4f' % (loss))
                print('RMSE val: %.4f' % (RMSE))
            else:
                loss, classify, regress = loss
                RMSE, accuracy = RMSE
                f.write('Training loss: %.4f\n' % (loss))
                f.write('Classification loss: %.4f\n' % (classify))
                f.write('Regression loss: %.4f\n' % (regress))
                f.write('Accuracy val: %.4f\n' % (accuracy))
                f.write('RMSE val: %.4f\n' % (RMSE))
                print('Training loss: %.4f' % (loss))
                print('Classification loss: %.4f' % (classify))
                print('Regression loss: %.4f' % (regress))
                print('Accuracy val: %.4f' % (accuracy))
                print('RMSE val: %.4f' % (RMSE))

            # torch.save(model.state_dict(), save_path + 'epoch%d.pth'%(epoch+last_epoch))
            save_name = save_path + 'epoch%d.pth' % (epoch + last_epoch)
            save_pth(save_name, epoch + last_epoch, model, optimizer, args.opt)

            plot_dict['train'].append(loss)
            plot_dict['val'].append(RMSE)
            with open(plot_train_path, 'w') as j:
                json.dump(plot_dict, j)
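
This example relies on a module-level args namespace that is never shown. A minimal sketch of the argparse setup it appears to assume (flag names inferred from the attribute accesses above; defaults are hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default='./data/')
parser.add_argument('--save_path', type=str, default='./checkpoints/')
parser.add_argument('--weights', type=str, default='')   # resume checkpoint
parser.add_argument('--fold', type=int, default=0)       # 0 = no k-fold split
parser.add_argument('--exp', action='store_true')        # anchor classification + regression head
parser.add_argument('--no_shuffle', action='store_true')
parser.add_argument('--opt', type=str, default='adam')   # radam | adamw | adam | sgd
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--image_size', type=int, default=256)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--epoch', type=int, default=50)
args = parser.parse_args()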
Example #2
def train(args, cfg):
    device = torch.device('cuda')
    model = ModelWithLoss(cfg).to(device)
    print('------------Model Architecture-------------')
    print(model)

    print('Loading Datasets...')
    data_loader = {}

    if cfg.SOLVER.AUGMENTATION:
        train_transforms = SyntheticTransforms()
    else:
        train_transforms = ToTensor()
        
    if cfg.DATASET.TRACK == 'synthetic':
        train_dataset = SyntheticBurst(ZurichRAW2RGB(cfg.DATASET.TRAIN_SYNTHETIC), crop_sz=cfg.SOLVER.PATCH_SIZE, burst_size=cfg.MODEL.BURST_SIZE, transform=train_transforms)
    elif cfg.DATASET.TRACK == 'real':
        train_dataset = BurstSRDataset(cfg.DATASET.REAL, split='train', crop_sz=cfg.SOLVER.PATCH_SIZE // 8, burst_size=cfg.MODEL.BURST_SIZE)
    else:
        raise ValueError('Unknown DATASET.TRACK: {}'.format(cfg.DATASET.TRACK))
    sampler = RandomSampler(train_dataset)
    batch_sampler = BatchSampler(sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=True)
    batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER)
    train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_sampler=batch_sampler, pin_memory=True)

    data_loader['train'] = train_loader

    # if args.eval_step != 0:
    #     val_transforms =
    #     val_dataset =
    #     sampler = SequentialSampler(val_dataset)
    #     batch_sampler = BatchSampler(sampler=sampler, batch_size=args.batch_size, drop_last=False)
    #     val_loader = DataLoader(val_dataset, num_workers=args.num_workers, batch_sampler=batch_sampler)

    #     data_loader['val'] = val_loader

    if cfg.SOLVER.OPTIMIZER == 'radam':
        optimizer = RAdam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.SOLVER.LR)
    elif cfg.SOLVER.OPTIMIZER == 'adabound':
        optimizer = AdaBound(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.SOLVER.LR, final_lr=cfg.SOLVER.FINAL_LR)
    else:
        raise ValueError('Unknown SOLVER.OPTIMIZER: {}'.format(cfg.SOLVER.OPTIMIZER))
    # optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.SOLVER.LR)
    # scheduler = MultiStepLR(optimizer, cfg.SOLVER.LR_STEP, gamma=0.1)
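    # WarmupMultiStepLR appears to be a custom scheduler: presumably linear
    # warmup for WARMUP_ITER iterations, then MultiStepLR-style decay at the
    # LR_STEP milestones (signature assumed from the call below)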
    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.LR, cfg.SOLVER.LR_STEP, warmup_factor=cfg.SOLVER.WARMUP_FACTOR, warmup_iters=cfg.SOLVER.WARMUP_ITER)

    if args.resume_iter != 0:
        model_path = os.path.join(cfg.OUTPUT_DIR, 'model', 'iteration_{}.pth'.format(args.resume_iter))
        print(f'Resume from {model_path}')
        model.model.load_state_dict(fix_model_state_dict(torch.load(model_path)))
        if model.flow_refine:
            FR_model_path = os.path.dirname(model_path)[:-5] + "FR_model/" + 'iteration_{}.pth'.format(args.resume_iter)
            model.FR_model.load_state_dict(torch.load(FR_model_path))
        if model.denoise_burst:
            denoise_model_path = os.path.dirname(model_path)[:-5] + "denoise_model/" + 'iteration_{}.pth'.format(args.resume_iter)
            model.denoise_model.load_state_dict(torch.load(denoise_model_path))
        optimizer.load_state_dict(torch.load(os.path.join(cfg.OUTPUT_DIR, 'optimizer', 'iteration_{}.pth'.format(args.resume_iter))))
        scheduler.load_state_dict(torch.load(os.path.join(cfg.OUTPUT_DIR, 'scheduler', 'iteration_{}.pth'.format(args.resume_iter))))
    elif cfg.SOLVER.PRETRAIN_MODEL != '':
        model_path = cfg.SOLVER.PRETRAIN_MODEL
        print(f'load pretrain model from {model_path}')
        model.model.load_state_dict(fix_model_state_dict(torch.load(model_path)))
        if model.flow_refine:
            FR_model_path = os.path.dirname(model_path)[:-5] + "FR_model/" + os.path.basename(cfg.SOLVER.PRETRAIN_MODEL)
            model.FR_model.load_state_dict(torch.load(FR_model_path))
        if model.denoise_burst:
            denoise_model_path = os.path.dirname(model_path)[:-5] + "denoise_model/" + os.path.basename(cfg.SOLVER.PRETRAIN_MODEL)
            model.denoise_model.load_state_dict(torch.load(denoise_model_path))

    if cfg.SOLVER.SYNC_BATCHNORM:
        model = convert_model(model).to(device)
    
    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpus)))

    if not args.debug:
        summary_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
    else:
        summary_writer = None

    do_train(args, cfg, model, optimizer, scheduler, data_loader, device, summary_writer)
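
IterationBasedBatchSampler is not defined in this example. A minimal sketch of the usual implementation (in the style of maskrcnn-benchmark, assumed rather than taken from this source): it resamples the wrapped BatchSampler until a fixed iteration budget is spent, so training is bounded by SOLVER.MAX_ITER instead of epochs.

class IterationBasedBatchSampler:
    """Yield batches from a wrapped BatchSampler until num_iterations is reached."""

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration < self.num_iterations:
            # Re-iterate the underlying sampler as many times as needed
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    return
                yield batch

    def __len__(self):
        return self.num_iterations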
Example #3
def train(rank: int, cfg: DictConfig):
    print(OmegaConf.to_yaml(cfg))

    if cfg.train.n_gpu > 1:
        init_process_group(backend=cfg.train.dist_config['dist_backend'],
                           init_method=cfg.train.dist_config['dist_url'],
                           world_size=cfg.train.dist_config['world_size'] *
                           cfg.train.n_gpu,
                           rank=rank)

    device = torch.device(
        'cuda:{:d}'.format(rank) if torch.cuda.is_available() else 'cpu')

    generator = Generator(sum(cfg.model.feature_dims), *cfg.model.cond_dims,
                          **cfg.model.generator).to(device)
    discriminator = Discriminator(**cfg.model.discriminator).to(device)

    if rank == 0:
        print(generator)
        os.makedirs(cfg.train.ckpt_dir, exist_ok=True)
        print("checkpoints directory : ", cfg.train.ckpt_dir)

    cp_g = cp_do = None
    if os.path.isdir(cfg.train.ckpt_dir):
        cp_g = scan_checkpoint(cfg.train.ckpt_dir, 'g_')
        cp_do = scan_checkpoint(cfg.train.ckpt_dir, 'd_')

    steps = 1
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        discriminator.load_state_dict(state_dict_do['discriminator'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if cfg.train.n_gpu > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        discriminator = DistributedDataParallel(discriminator,
                                                device_ids=[rank]).to(device)

    optim_g = RAdam(generator.parameters(), cfg.opt.lr, betas=cfg.opt.betas)
    optim_d = RAdam(discriminator.parameters(),
                    cfg.opt.lr,
                    betas=cfg.opt.betas)

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)

    train_filelist = load_dataset_filelist(cfg.dataset.train_list)
    trainset = FeatureDataset(cfg.dataset, train_filelist, cfg.data)
    train_sampler = DistributedSampler(
        trainset) if cfg.train.n_gpu > 1 else None
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train.batch_size,
                              num_workers=cfg.train.num_workers,
                              shuffle=(train_sampler is None),
                              sampler=train_sampler,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        val_filelist = load_dataset_filelist(cfg.dataset.test_list)
        valset = FeatureDataset(cfg.dataset,
                                val_filelist,
                                cfg.data,
                                segmented=False)
        val_loader = DataLoader(valset,
                                batch_size=1,
                                num_workers=cfg.train.num_workers,
                                shuffle=False,
                                sampler=None,
                                pin_memory=True)

        sw = SummaryWriter(os.path.join(cfg.train.ckpt_dir, 'logs'))

    generator.train()
    discriminator.train()
    for epoch in range(max(0, last_epoch), cfg.train.epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if cfg.train.n_gpu > 1:
            train_sampler.set_epoch(epoch)

        for y, x_noised_features, x_noised_cond in train_loader:
            if rank == 0:
                start_b = time.time()

            y = y.to(device, non_blocking=True)
            x_noised_features = x_noised_features.transpose(1, 2).to(
                device, non_blocking=True)
            x_noised_cond = x_noised_cond.to(device, non_blocking=True)
            z1 = torch.randn(cfg.train.batch_size,
                             cfg.model.cond_dims[1],
                             device=device)
            z2 = torch.randn(cfg.train.batch_size,
                             cfg.model.cond_dims[1],
                             device=device)

            y_hat1 = generator(x_noised_features, x_noised_cond, z=z1)
            y_hat2 = generator(x_noised_features, x_noised_cond, z=z2)

            # Discriminator update: y_hat1 is detached so only D trains here
            real_scores, fake_scores = discriminator(y), discriminator(
                y_hat1.detach())
            d_loss = discriminator_loss(real_scores, fake_scores)

            optim_d.zero_grad()
            d_loss.backward()
            optim_d.step()

            # Generator update: re-score y_hat1 without detaching so the
            # adversarial gradient actually reaches the generator (the scores
            # above were computed from a detached tensor). criterion is
            # assumed to be a module-level reconstruction (STFT) loss; the
            # negative cross term presumably rewards diversity between the
            # two noise-conditioned outputs.
            fake_scores = discriminator(y_hat1)
            g_stft_loss = criterion(y, y_hat1) + criterion(
                y, y_hat2) - criterion(y_hat1, y_hat2)
            g_adv_loss = adversarial_loss(fake_scores)
            g_loss = g_adv_loss + g_stft_loss

            optim_g.zero_grad()
            g_loss.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % cfg.train.stdout_interval == 0:
                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, STFT Error : {:4.3f}, s/b : {:4.3f}'
                        .format(steps, g_loss.item(), g_stft_loss.item(),
                                time.time() - start_b))

                # checkpointing
                if steps % cfg.train.checkpoint_interval == 0:
                    ckpt_dir = "{}/g_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(
                        ckpt_dir, {
                            'generator':
                            (generator.module if cfg.train.n_gpu > 1 else
                             generator).state_dict()
                        })
                    ckpt_dir = "{}/do_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(
                        ckpt_dir, {
                            'discriminator':
                            (discriminator.module if cfg.train.n_gpu > 1 else
                             discriminator).state_dict(),
                            'optim_g':
                            optim_g.state_dict(),
                            'optim_d':
                            optim_d.state_dict(),
                            'steps':
                            steps,
                            'epoch':
                            epoch
                        })

                # Tensorboard summary logging
                if steps % cfg.train.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", g_loss, steps)
                    sw.add_scalar("training/gen_stft_error", g_stft_loss,
                                  steps)

                # Validation
                if steps % cfg.train.validation_interval == 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, (y, x_noised_features,
                                x_noised_cond) in enumerate(val_loader):
                            y = y.to(device, non_blocking=True)
                            y_hat = generator(
                                x_noised_features.transpose(1, 2).to(device),
                                x_noised_cond.to(device))
                            val_err_tot += criterion(y, y_hat).item()

                            if j <= 4:
                                # sw.add_audio('noised/y_noised_{}'.format(j), y_noised[0], steps, cfg.data.target_sample_rate)
                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             y_hat[0], steps,
                                             cfg.data.sample_rate)
                                sw.add_audio('gt/y_{}'.format(j), y[0], steps,
                                             cfg.data.sample_rate)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/stft_error", val_err, steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
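
scan_checkpoint and load_checkpoint are helpers defined outside this snippet. A minimal sketch of what such helpers conventionally do in HiFi-GAN-style training code (an assumption, not taken from this source): pick the newest file matching a filename prefix, and load it onto the requested device.

import glob
import os

import torch


def scan_checkpoint(cp_dir, prefix):
    # Newest checkpoint whose filename starts with prefix, or None if absent
    cp_list = glob.glob(os.path.join(cp_dir, prefix + '*'))
    return sorted(cp_list)[-1] if cp_list else None


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    return torch.load(filepath, map_location=device)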