Example 1
    def train_mc(self,
                 dataloader,
                 physics,
                 epochs,
                 lr,
                 ckp_interval,
                 schedule,
                 residual=True,
                 pretrained=None,
                 task='',
                 loss_type='l2',
                 cat=True,
                 report_psnr=False,
                 lr_cos=False):
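        """Train the generator with a measurement-consistency (MC) loss.

        Builds a compact residual U-Net, optionally warm-starts it from
        `pretrained`, and optimizes it with Adam using an L1 or L2 criterion
        evaluated by `closure_mc` on the measurements produced by `physics`.
        Losses (and optionally PSNR/MSE) are written to a CSV log and
        checkpoints are saved every `ckp_interval` epochs.
        """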
        save_path = './ckp/{}_mc_{}_{}'.format(get_timestamp(),
                                               'res' if residual else '',
                                               task)

        os.makedirs(save_path, exist_ok=True)

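        # Reconstruction network: a compact residual U-Net with circular padding.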
        generator = UNet(in_channels=self.in_channels,
                         out_channels=self.out_channels,
                         compact=4,
                         residual=residual,
                         circular_padding=True,
                         cat=cat).to(self.device)

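        # Optionally warm-start the generator from a saved checkpoint.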
        if pretrained:
            checkpoint = torch.load(pretrained, map_location=self.device)
            generator.load_state_dict(checkpoint['state_dict'])

        # Select the measurement-consistency criterion.
        if loss_type == 'l2':
            criterion_mc = torch.nn.MSELoss().to(self.device)
        elif loss_type == 'l1':
            criterion_mc = torch.nn.L1Loss().to(self.device)
        else:
            raise ValueError('unsupported loss_type: {}'.format(loss_type))

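        # lr packs the optimizer hyperparameters: lr['G'] is the learning
        # rate and lr['WD'] the weight decay.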
        optimizer = Adam(generator.parameters(),
                         lr=lr['G'],
                         weight_decay=lr['WD'])

        if report_psnr:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=['epoch', 'loss_mc', 'psnr', 'mse'])
        else:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=['epoch', 'loss_mc'])

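        # Training loop: step the LR schedule, run one epoch via closure_mc,
        # log, and checkpoint.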
        for epoch in range(epochs):
            adjust_learning_rate(optimizer, epoch, lr['G'], lr_cos, epochs,
                                 schedule)
            loss = closure_mc(generator, dataloader, physics, optimizer,
                              criterion_mc, self.dtype, self.device,
                              report_psnr)

            log.record(epoch + 1, *loss)

            if report_psnr:
                print('{}\tEpoch[{}/{}]\tmc={:.4e}\tpsnr={:.4f}\tmse={:.4e}'.
                      format(get_timestamp(), epoch + 1, epochs, *loss))
            else:
                print('{}\tEpoch[{}/{}]\tmc={:.4e}'.format(
                    get_timestamp(), epoch + 1, epochs, *loss))

            if epoch % ckp_interval == 0 or epoch + 1 == epochs:
                state = {
                    'epoch': epoch,
                    'state_dict': generator.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(
                    state,
                    os.path.join(save_path, 'ckp_{}.pth.tar'.format(epoch)))
        log.close()
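A minimal call sketch. The names here are hypothetical: `model` stands for an instance of whatever class defines train_mc, `physics` for a forward-operator object accepted by closure_mc, and `train_set` for an already-built dataset; the hyperparameter values are illustrative, not taken from the repository.

from torch.utils.data import DataLoader

loader = DataLoader(train_set, batch_size=2, shuffle=True)

model.train_mc(dataloader=loader,
               physics=physics,
               epochs=500,
               lr={'G': 1e-4, 'WD': 1e-8},  # keys read by the method above
               ckp_interval=100,
               schedule=[300, 400],         # passed to adjust_learning_rate
               residual=True,
               task='inpainting',
               loss_type='l2',
               report_psnr=True)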
Example 2
    def train_ei_adv(self,
                     dataloader,
                     physics,
                     transform,
                     epochs,
                     lr,
                     alpha,
                     ckp_interval,
                     schedule,
                     residual=True,
                     pretrained=None,
                     task='',
                     loss_type='l2',
                     cat=True,
                     report_psnr=False,
                     lr_cos=False):
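        """Train with measurement-consistency, equivariant-imaging (EI), and
        adversarial losses.

        A U-Net generator and a discriminator are both updated by
        `closure_ei_adv`; `transform` supplies the transformations used by
        the EI term and `alpha` is forwarded to the closure as the
        loss-weighting term. Each checkpoint stores both networks and both
        optimizers.
        """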
        save_path = './ckp/{}_ei_adv_{}'.format(get_timestamp(), task)

        os.makedirs(save_path, exist_ok=True)

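        # Same compact residual U-Net as in train_mc; it is moved to the
        # device below, after the optional weight loading.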
        generator = UNet(in_channels=self.in_channels,
                         out_channels=self.out_channels,
                         compact=4,
                         residual=residual,
                         circular_padding=True,
                         cat=cat)

        if pretrained:
            checkpoint = torch.load(pretrained, map_location=self.device)
            generator.load_state_dict(checkpoint['state_dict'])

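        # Discriminator over full images of shape
        # (in_channels, img_width, img_height).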
        discriminator = Discriminator(
            (self.in_channels, self.img_width, self.img_height))

        generator = generator.to(self.device)
        discriminator = discriminator.to(self.device)

        # Select the measurement-consistency and equivariance criteria.
        if loss_type == 'l2':
            criterion_mc = torch.nn.MSELoss().to(self.device)
            criterion_ei = torch.nn.MSELoss().to(self.device)
        elif loss_type == 'l1':
            criterion_mc = torch.nn.L1Loss().to(self.device)
            criterion_ei = torch.nn.L1Loss().to(self.device)
        else:
            raise ValueError('unsupported loss_type: {}'.format(loss_type))

        # Least-squares GAN objective for the adversarial term.
        criterion_gan = torch.nn.MSELoss().to(self.device)

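        # Separate Adam optimizers; weight decay is applied to the generator
        # only.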
        optimizer_G = Adam(generator.parameters(),
                           lr=lr['G'],
                           weight_decay=lr['WD'])
        optimizer_D = Adam(discriminator.parameters(),
                           lr=lr['D'],
                           weight_decay=0)

        if report_psnr:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=[
                          'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G',
                          'loss_D', 'psnr', 'mse'
                      ])
        else:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=[
                          'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G',
                          'loss_D'
                      ])

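        # Training loop: schedule both optimizers, run one adversarial epoch
        # via closure_ei_adv, log, and checkpoint.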
        for epoch in range(epochs):
            adjust_learning_rate(optimizer_G, epoch, lr['G'], lr_cos, epochs,
                                 schedule)
            adjust_learning_rate(optimizer_D, epoch, lr['D'], lr_cos, epochs,
                                 schedule)

            loss = closure_ei_adv(generator, discriminator, dataloader,
                                  physics, transform, optimizer_G, optimizer_D,
                                  criterion_mc, criterion_ei, criterion_gan,
                                  alpha, self.dtype, self.device, report_psnr)

            log.record(epoch + 1, *loss)

            if report_psnr:
                print(
                    '{}\tEpoch[{}/{}]\tmc={:.4e}\tei={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}\tpsnr={:.4f}\tmse={:.4e}'
                    .format(get_timestamp(), epoch + 1, epochs, *loss))
            else:
                print(
                    '{}\tEpoch[{}/{}]\tmc={:.4e}\tei={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}'
                    .format(get_timestamp(), epoch + 1, epochs, *loss))

            if epoch % ckp_interval == 0 or epoch + 1 == epochs:
                state = {
                    'epoch': epoch,
                    'state_dict_G': generator.state_dict(),
                    'state_dict_D': discriminator.state_dict(),
                    'optimizer_G': optimizer_G.state_dict(),
                    'optimizer_D': optimizer_D.state_dict()
                }
                torch.save(
                    state,
                    os.path.join(save_path, 'ckp_{}.pth.tar'.format(epoch)))
        log.close()
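The adversarial variant additionally needs a transform for the EI term, a discriminator learning rate under lr['D'], and the weighting term alpha. A sketch under the same hypothetical names as above, with `transform` assumed to be a transformation object accepted by closure_ei_adv; the form of `alpha` is an assumption, since it is defined by closure_ei_adv:

model.train_ei_adv(dataloader=loader,
                   physics=physics,
                   transform=transform,
                   epochs=500,
                   lr={'G': 1e-4, 'D': 1e-4, 'WD': 1e-8},  # 'D' is read for the discriminator
                   alpha=1.0,            # assumed scalar weight; see closure_ei_adv
                   ckp_interval=100,
                   schedule=[300, 400],
                   loss_type='l2',
                   report_psnr=True)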