Example #1
import gc
import os
import time

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

# Assumed imports above; MyDataset, Net, FocalLoss, accuracy, evaluate,
# config, device and model_name come from the surrounding project.
def train(train_data, val_data, fold_idx=None):
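    """Train the model (optionally as one fold of K-fold CV), keep the
    checkpoint with the best validation accuracy, and stop once repeated
    LR reductions no longer help."""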
    train_data = MyDataset(train_data, train_transform)
    train_loader = DataLoader(train_data,
                              batch_size=config.batch_size,
                              shuffle=True)

    val_data = MyDataset(val_data, val_transform)
    val_loader = DataLoader(val_data,
                            batch_size=config.batch_size,
                            shuffle=False)

    model = Net(model_name).to(device)
    # criterion = nn.CrossEntropyLoss()
    criterion = FocalLoss(0.5)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)

    if fold_idx is None:
        print('start')
        model_save_path = os.path.join(config.model_path,
                                       '{}.bin'.format(model_name))
    else:
        print('start fold: {}'.format(fold_idx + 1))
        model_save_path = os.path.join(
            config.model_path, '{}_fold{}.bin'.format(model_name, fold_idx))
    if os.path.isfile(model_save_path):
        print('Loading the previously trained model')
        model.load_state_dict(torch.load(model_save_path))

    best_val_acc = 0
    last_improved_epoch = 0
    adjust_lr_num = 0
    for cur_epoch in range(config.epochs_num):
        start_time = int(time.time())
        model.train()
        print('epoch: {}, steps per epoch: {}'.format(cur_epoch + 1, len(train_loader)))
        cur_step = 0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            probs = model(batch_x)

            train_loss = criterion(probs, batch_y)
            train_loss.backward()
            optimizer.step()

            cur_step += 1
            if cur_step % config.train_print_step == 0:
                train_acc = accuracy(probs, batch_y)
                msg = 'the current step: {0}/{1}, train loss: {2:>5.2}, train acc: {3:>6.2%}'
                print(
                    msg.format(cur_step, len(train_loader), train_loss.item(),
                               train_acc[0].item()))
        val_loss, val_acc = evaluate(model, val_loader, criterion)
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save_path)
            improved_str = '*'
            last_improved_epoch = cur_epoch
        else:
            improved_str = ''
        msg = 'the current epoch: {0}/{1}, val loss: {2:>5.2}, val acc: {3:>6.2%}, cost: {4}s {5}'
        end_time = int(time.time())
        print(
            msg.format(cur_epoch + 1, config.epochs_num, val_loss, val_acc,
                       end_time - start_time, improved_str))
        if cur_epoch - last_improved_epoch > config.patience_epoch:
            print("No optimization for a long time, adjust lr...")
            scheduler.step()
            last_improved_epoch = cur_epoch  # reset, otherwise the lr would be cut again on every following epoch
            adjust_lr_num += 1
            if adjust_lr_num > config.adjust_lr_num:
                print("No optimization for a long time, auto stopping...")
                break
    del model
    gc.collect()
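FocalLoss and accuracy are referenced but not defined in this snippet. A minimal sketch of plausible definitions is shown below, assuming the 0.5 argument is the focusing parameter gamma and that accuracy returns a list of fractions so that train_acc[0] works as indexed above; the project's actual implementations may differ.

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss: down-weights well-classified examples."""

    def __init__(self, gamma=0.5):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, target):
        # log-probability of the true class for each sample
        log_pt = F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        return (-(1.0 - pt) ** self.gamma * log_pt).mean()


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy as fractions in [0, 1]; acc[0] is top-1."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1)     # (N, maxk) class indices
        hits = pred.eq(target.view(-1, 1))     # (N, maxk) boolean matches
        return [hits[:, :k].any(dim=1).float().mean() for k in topk]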
Example #2
import os
import shutil
import sys

import torch
import torch.utils.data
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

# Assumed imports above (SummaryWriter may come from tensorboardX instead);
# Config, vgg, Net, FaceDataset, InfiniteSamplerWrapper, poly_lr_scheduler,
# parse_args and generate_sample come from the surrounding project.
def main():
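    """Set up config, logging, data and the encoder/decoder model, then run
    the content/style training loop with periodic checkpoints and previews."""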
    global device, cfg
    args = parse_args()
    cfg = Config.from_file(args.config)

    out = cfg.train.out
    if not os.path.exists(out):
        os.makedirs(out)

    # save config and command
    commands = sys.argv
    with open(f'{out}/command.txt', 'w') as f:
        f.write('## Command ################\n\n')
        f.write(f'python {commands[0]} ')
        for command in commands[1:]:
            f.write(command + ' ')
        f.write('\n\n\n')
        f.write('## Args ###################\n\n')
        for name in vars(args):
            f.write(f'{name} = {getattr(args, name)}\n')

    shutil.copy(args.config, f'./{out}')

    # Log
    logdir = os.path.join(out, 'log')
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    writer = SummaryWriter(log_dir=logdir)

    # Set device
    cuda = torch.cuda.is_available()
    if cuda and args.gpu >= 0:
        print('# cuda available! #')
        device = torch.device(f'cuda:{args.gpu}')
    else:
        device = torch.device('cpu')

    # Set models
    # vgg.VGG is assumed to be an instantiated torch.nn.Sequential holding a
    # pretrained VGG; keeping the first 31 children truncates the encoder at
    # the usual relu4_1 cut-off.
    VGG = vgg.VGG
    VGG.load_state_dict(torch.load(args.vgg))
    VGG = torch.nn.Sequential(*list(VGG.children())[:31])
    model = Net(VGG)
    model.to(device)

    # Prepare dataset
    content_dataset = FaceDataset(cfg, cfg.train.content_dataset)
    content_loader = torch.utils.data.DataLoader(
        content_dataset,
        batch_size=cfg.train.batchsize,
        shuffle=True,
        num_workers=min(cfg.train.batchsize, 16),
        pin_memory=True,
        drop_last=True)
    style_dataset = FaceDataset(cfg, cfg.train.style_dataset)
    style_loader = torch.utils.data.DataLoader(
        style_dataset,
        batch_size=cfg.train.batchsize,
        sampler=InfiniteSamplerWrapper(style_dataset),
        num_workers=0,
        pin_memory=True,
        drop_last=True)
    style_iter = iter(style_loader)
    print(f'content dataset contains {len(content_dataset)} images.')
    print(f'style dataset contains {len(style_dataset)} images.')

    opt = Adam(model.decoder.parameters(),
               lr=cfg.train.parameters.lr,
               betas=(0.5, 0.999))

    iteration = 0
    iterations_per_epoch = len(content_loader)
    epochs = cfg.train.iterations // iterations_per_epoch
    for epoch in range(epochs):
        for i, batch in enumerate(content_loader):
            model.train()

            # torch.autograd.Variable is a deprecated no-op since PyTorch 0.4
            content_images = batch.to(device)
            style_images = next(style_iter).to(device)

            # weighted sum of the content and style reconstruction losses
            loss_c, loss_s = model(content_images, style_images)
            loss = cfg.train.parameters.lam_c * loss_c + cfg.train.parameters.lam_s * loss_s

            opt.zero_grad()
            loss.backward()
            opt.step()

            writer.add_scalar('loss_content', loss_c.item(), iteration + 1)
            writer.add_scalar('loss_style', loss_s.item(), iteration + 1)

            lr = poly_lr_scheduler(opt,
                                   cfg.train.parameters.lr,
                                   iteration,
                                   lr_decay_iter=10,
                                   max_iter=cfg.train.iterations)
            iteration += 1

            if iteration % cfg.train.print_interval == 0:
                print(
                    f'Epoch:[{epoch}][{iteration}/{cfg.train.iterations}]  loss content:{loss_c.item():.5f} loss style:{loss_s.item():.5f}'
                )

            if iteration % cfg.train.save_interval == 0:
                if not os.path.exists(os.path.join(out, 'checkpoint')):
                    os.makedirs(os.path.join(out, 'checkpoint'))
                path = os.path.join(out, 'checkpoint',
                                    f'iter_{iteration:04d}.pth.tar')
                state = {
                    'state_dict': model.state_dict(),
                    'opt_state_dict': opt.state_dict(),
                    'iteration': iteration,
                }
                torch.save(state, path)

            if iteration % cfg.train.preview_interval == 0:
                if not os.path.exists(os.path.join(out, 'preview')):
                    os.makedirs(os.path.join(out, 'preview'))
                sample = generate_sample(model, content_images, style_images)
                save_image(
                    sample.data.cpu(),
                    os.path.join(out, 'preview', f'iter_{iteration:04d}.png'))
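InfiniteSamplerWrapper and poly_lr_scheduler also come from the surrounding project. The sketch below shows one plausible shape for each, matching how they are called above: an endless random sampler so next(style_iter) never raises StopIteration, and a polynomial learning-rate decay; treat the exact formula and defaults as assumptions.

import numpy as np
import torch.utils.data as data


class InfiniteSamplerWrapper(data.Sampler):
    """Yields shuffled indices forever, so the style iterator never ends."""

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        while True:
            for i in np.random.permutation(self.num_samples):
                yield int(i)

    def __len__(self):
        return 2 ** 31  # effectively infinite


def poly_lr_scheduler(optimizer, init_lr, cur_iter, lr_decay_iter=10,
                      max_iter=100000, power=0.9):
    """Polynomial decay: lr = init_lr * (1 - cur_iter / max_iter) ** power."""
    if cur_iter % lr_decay_iter != 0 or cur_iter > max_iter:
        return optimizer.param_groups[0]['lr']
    lr = init_lr * (1 - cur_iter / max_iter) ** power
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr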