예제 #1
0
파일: train.py 프로젝트: xidulu/ResNet-VAE
def train_model(loss, batch_size, num_epochs, learning_rate):
    """Train an Encoder/Decoder pair on CIFAR, then save weights and a test log.

    Args:
        loss: Callable invoked as ``loss(batch, enc, dec)``. It must return a
            pair ``(loss_value, recon)`` where ``loss_value`` is a scalar
            tensor that supports ``backward()`` and ``recon`` is compared
            against ``batch`` with mean-squared error during evaluation.
        batch_size: Batch size forwarded to ``get_cifar``.
        num_epochs: Number of full passes over the training loader.
        learning_rate: Learning rate passed to the Adam optimizer.

    Side effects:
        * Prints a progress line every 10 training batches and the running
          test averages after every epoch.
        * Writes ``./ckpt/enc.pt`` and ``./ckpt/dec.pt`` (state dicts, moved
          to CPU first).
        * Writes ``./log/log.txt`` with one line per epoch containing the
          averaged test loss and the averaged test MSE, tab-separated.
    """
    # Local import so this block does not depend on a file-level ``import os``.
    import os

    enc = Encoder(**cifar_config().encoder_arc).to(device)
    dec = Decoder(**cifar_config().decoder_arc).to(device)
    model = [enc, dec]
    # A single optimizer drives the parameters of both networks.
    gd = optim.Adam(
        chain.from_iterable(
            x.parameters() for x in model
            if isinstance(x, (nn.Module, nn.Parameter))),
        learning_rate,
        weight_decay=1e-5)
    train_loader, test_loader = get_cifar(batch_size=batch_size,
                                          num_workers=32)
    # Loop invariants: number of batches per epoch and the MSE criterion.
    total = len(train_loader)
    mse_criterion = torch.nn.MSELoss()

    train_losses = []
    test_results = []
    for cnt in range(num_epochs):
        for i, (batch, _) in enumerate(train_loader):
            gd.zero_grad()
            batch = batch.to(device)
            loss_value, _ = loss(batch, enc, dec)
            loss_value.backward()
            train_losses.append(loss_value.item())
            if (i + 1) % 10 == 0:
                # '\r' rewrites the same console line; the trailing spaces
                # blank out leftovers from a longer previous message.
                print('\rTrain loss:',
                      train_losses[-1],
                      'Batch',
                      i + 1,
                      'of',
                      total,
                      ' ' * 10,
                      end='',
                      flush=True)
            gd.step()

        # NOTE(review): the networks are not switched to eval() here, which
        # matters if they contain dropout / batch-norm layers -- confirm
        # whether that is intended.
        test_elbo = 0.
        test_mse = 0.
        with torch.autograd.no_grad():
            for i, (batch, _) in enumerate(test_loader):
                batch = batch.to(device)
                batch_loss, recon = loss(batch, enc, dec)
                # Incremental running mean: avg += (x - avg) / n.
                test_mse += (mse_criterion(recon, batch) - test_mse) / (i + 1)
                test_elbo += (batch_loss - test_elbo) / (i + 1)
        print('\nTest elbo after at epoch {}: {}'.format(cnt, test_elbo))
        print('Test mse after at epoch {}: {}'.format(cnt, test_mse))
        # Keep plain floats so the results can be formatted later without
        # holding on to (possibly device-resident) tensors.
        test_results.append((float(test_elbo), float(test_mse)))

    enc.cpu()
    dec.cpu()
    # Create the output directories so a finished run is never lost because
    # a folder is missing.
    os.makedirs('./ckpt', exist_ok=True)
    os.makedirs('./log', exist_ok=True)
    torch.save(enc.state_dict(), "./ckpt/enc.pt")
    torch.save(dec.state_dict(), "./ckpt/dec.pt")
    with open('./log/log.txt', 'w') as f:
        # Each entry is an (elbo, mse) pair; the original passed the whole
        # tuple to float(), which raises TypeError.
        for elbo, mse in test_results:
            f.write("%s\t%s\n" % (elbo, mse))
예제 #2
0
                # timer   # ta
                time_str = self.calculate_remaining(start_time, time.time(),
                                                    iter / batch_per_epoch)
                # save model   # ta
                self.save(iter)

                self.logger.write(disp_str)
                sys.stdout.write(disp_str)
                sys.stdout.write(time_str)  # ta
                sys.stdout.flush()

            iter += 1
            self.iter_cnt += 1


if __name__ == '__main__':
    # Script entry point: parse the command line, build a Tester from the
    # CIFAR configuration and run its test routine.
    parser = argparse.ArgumentParser(description='model_tester.py')

    # Both options are string-valued and share the same help text.
    # NOTE(review): the help for '-r' duplicates the one for '-suffix';
    # the intended meaning of '-r' is not visible here -- confirm.
    option_help = "Suffix added to the save images."
    for option_flag, option_default in (('-suffix', 'run0'), ('-r', '')):
        parser.add_argument(option_flag,
                            type=str,
                            default=option_default,
                            help=option_help)

    args = parser.parse_args()

    tester = Tester(config.cifar_config(), args)
    tester.test()
예제 #3
0
                                                    iter / batch_per_epoch)
                # save model   # ta
                self.save(iter)

                self.logger.write(disp_str)
                sys.stdout.write(disp_str)
                sys.stdout.write(time_str)  # ta
                sys.stdout.flush()

            iter += 1
            self.iter_cnt += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='cifar_trainer.py')
    cc = config.cifar_config()
    parser.add_argument('-suffix',
                        default='run0',
                        type=str,
                        help="Suffix added to the save images.")
    parser.add_argument('-r',
                        default='',
                        type=str,
                        help="Suffix added to the save images.")
    parser.add_argument('-max_epochs',
                        default=cc.max_epochs,
                        type=int,
                        help="max epoches")
    parser.add_argument('-ld',
                        '--size_labeled_data',
                        default=cc.size_labeled_data,
예제 #4
0
                disp_str = '#{}-{}\ttrain: {:.4f}, {:.2f}% | dev: {:.4f}, {:.2f}%'.format(
                    int(epoch), iter, train_loss, train_accuracy * 100,
                    dev_loss, dev_accuracy * 100)
                for k, v in monitor.items():
                    disp_str += ' | {}: {:.4f}'.format(k,
                                                       v / config.eval_period)
                disp_str += '\n'

                monitor = OrderedDict()

                self.logger.write(disp_str)
                self.logger.flush()
                sys.stdout.write(disp_str)
                sys.stdout.flush()

            iter += 1
            self.iter_cnt += 1


if __name__ == '__main__':
    # Script entry point: read the run suffix from the command line, then
    # construct a Trainer from the CIFAR configuration and start training.
    parser = argparse.ArgumentParser(description='cifar_trainer.py')

    # Settings for the single '--suffix' option, expanded as keyword args.
    suffix_settings = dict(default='run0',
                           type=str,
                           help="Suffix added to the save images.")
    parser.add_argument('--suffix', **suffix_settings)

    args = parser.parse_args()

    trainer = Trainer(config.cifar_config(), args)
    trainer.train()