Code example #1
File: mpm_train.py  Project: versey-sherry/MPM
# Imports reconstructed for context; MPM_Dataset, RMSE_Q_NormLoss and eval_net
# are project-local (defined elsewhere in versey-sherry/MPM).
import logging
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from hydra.utils import to_absolute_path  # the cfg objects suggest Hydra configs


def train_net(net, device, cfg):

    if cfg.eval.imgs is not None:
        train = MPM_Dataset(cfg.train, cfg.dataloader)
        val = MPM_Dataset(cfg.eval, cfg.dataloader)
        n_train = len(train)
        n_val = len(val)
    else:
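        # no separate eval set configured: split a validation subset off the training data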
        dataset = MPM_Dataset(cfg.train, cfg.dataloader)
        n_val = int(len(dataset) * cfg.eval.rate)
        n_train = len(dataset) - n_val
        train, val = random_split(dataset, [n_train, n_val])

    epochs = cfg.train.epochs
    batch_size = cfg.train.batch_size
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)

    # note: SummaryWriter ignores `comment` when an explicit log_dir is given
    writer = SummaryWriter(log_dir=to_absolute_path('./logs'),
                           comment=f'LR_{cfg.train.lr}_BS_{batch_size}')
    global_step = 0

    optimizer = optim.Adam(net.parameters(), lr=cfg.train.lr)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
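    # RMSE_Q_NormLoss is project-specific (not a torch built-in); the 0.8 is
    # presumably its quantile / weighting parameter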
    criterion = RMSE_Q_NormLoss(0.8)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {cfg.train.lr}
        Training size:   {len(train)}
        Validation size: {len(val)}
        Checkpoints:     {cfg.output.save}
        Device:          {device.type}
        Intervals:       {cfg.train.itvs}
        Optimizer:       {optimizer.__class__.__name__}
        Criterion:       {criterion.__class__.__name__}
    ''')

    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['img']
                mpms_gt = batch['mpm']

                imgs = imgs.to(device=device, dtype=torch.float32)
                mpms_gt = mpms_gt.to(device=device, dtype=torch.float32)

                mpms_pred = net(imgs)
                loss = criterion(mpms_pred, mpms_gt)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
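                # clamp each gradient element to [-0.1, 0.1] to curb exploding gradients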
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                # log roughly once per epoch; max() guards against n_train < batch_size
                if global_step % max(n_train // batch_size, 1) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        writer.add_histogram('grads/' + tag,
                                             value.grad.data.cpu().numpy(),
                                             global_step)
                    val_loss = eval_net(net, val_loader, device, criterion,
                                        writer, global_step)
                    # scheduler.step(val_score)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)

                    logging.info('Validation loss: {}'.format(val_loss))
                    writer.add_scalar('Loss/test', val_loss, global_step)

        if cfg.output.save:
            # create the checkpoint directory on first use
            if not os.path.isdir(cfg.output.dir):
                os.makedirs(cfg.output.dir)
                logging.info('Created checkpoint directory')
            torch.save(
                net.state_dict(),
                os.path.join(cfg.output.dir, f'CP_epoch{epoch + 1}.pth'))
            logging.info(f'Checkpoint {epoch + 1} saved!')

    writer.close()
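
eval_net here comes from elsewhere in the project and is not shown on this page. As a rough guide, a minimal sketch consistent with the call site above (same arguments, returning the mean validation loss) might look like the following; the body is an assumption, not the project's actual implementation:

def eval_net(net, loader, device, criterion, writer, global_step):
    """Hypothetical sketch: average the criterion over the validation loader."""
    net.eval()
    total, batches = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            imgs = batch['img'].to(device=device, dtype=torch.float32)
            mpms_gt = batch['mpm'].to(device=device, dtype=torch.float32)
            total += criterion(net(imgs), mpms_gt).item()
            batches += 1
    net.train()  # hand control back to the caller in training mode
    # writer/global_step could be used to log sample predictions; unused here
    return total / max(batches, 1)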
Code example #2
# NOTE: this example is truncated in the source; it starts partway through
# the per-epoch training loop (setup of net, optimizer, batch_size, the gpu
# flag and the f0/f1/f2 log files is not shown).
                targets = targets.to(torch.device('cuda:0'))

            outputs = net(imgs)
            loss = criterion(outputs, targets)
            epoch_loss += loss[0].item()  # loss is apparently a 1-element tensor (indexed here, backward() below)
            f0.write('{}\n'.format(loss[0].item()))
            N_train = len(loader.ids['train'])
            print('\rTraining...[{0}/{1}] --- MSE : {2:.6f}'.format(
                i * batch_size, N_train, loss[0].item()),
                  end='')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss = epoch_loss / (i + 1)
        print('\nEpoch finished ! Loss: {}'.format(loss))
        f1.write('{}\t{}\n'.format(epoch + 1, loss))

        val_loss = eval_net(net, loader.ids['val'], criterion, gpu,
                            dir_checkpoint, epoch + 1, loader.cell, loader.mpm)
        print('\nvalidation MSE Loss: {}'.format(val_loss))
        f2.write('{}\t{}\n'.format(epoch + 1, val_loss))

        torch.save(net.state_dict(),
                   dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
        # delete the previous epoch's checkpoint when validation loss increased
        if pre_val_loss < val_loss:
            os.remove(dir_checkpoint + 'CP{}.pth'.format(epoch))
        pre_val_loss = val_loss
        print('Checkpoint {} saved !'.format(epoch + 1))
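
The final block above prunes old checkpoints by deleting the previous epoch's file whenever validation loss rose. A common alternative is to keep a single best-so-far file; the sketch below is illustrative (save_if_best and CP_best.pth are made-up names, not part of the project):

import os
import torch

def save_if_best(net, val_loss, best_loss, dir_checkpoint):
    """Hypothetical helper: overwrite one 'best' checkpoint only on improvement."""
    if val_loss < best_loss:
        torch.save(net.state_dict(),
                   os.path.join(dir_checkpoint, 'CP_best.pth'))
        return val_loss  # new best loss
    return best_loss

In the loop above this would replace the save/remove pair, with the running best initialized to float('inf').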