예제 #1
0
def train():
    """Train the auto-encoder on the Pokemon dataset and save its weights.

    The model is optimised with Adam to reconstruct its own input (MSE
    between the output and the input batch).  The running mean loss of the
    current epoch is written to ``train.log`` after every step, and once
    more at the end of each epoch.  After the last epoch the model states
    are saved via ``auto_encoder.save_states()``.
    """
    ###################
    #    Load Data    #
    ###################

    dataset = PokemonDataset(add_mirrored=True)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE)
    # Number of batches per epoch; loop-invariant, so computed once.
    num_batches = len(data_loader)

    ###################
    #  Set Up Models  #
    ###################

    auto_encoder = models.AutoEncoder()

    optimizer = optim.Adam(auto_encoder.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss()

    ###################
    #    Training     #
    ###################

    auto_encoder.train()
    log = Logger('train.log', TrainingLogTemplate())

    # try/finally guarantees the log is closed even if a step raises.
    try:
        for epoch in range(EPOCHS):
            epoch_loss = 0
            ctx = {}
            for step, (x, _) in enumerate(data_loader):
                # Train: the reconstruction target is the input itself.
                y = auto_encoder(x)
                optimizer.zero_grad()
                loss = criterion(y, x)
                loss.backward()
                optimizer.step()

                # Display status: running mean of the loss over this epoch.
                epoch_loss += loss.item()
                ctx = {
                    'epoch': epoch + 1,
                    'epochs': EPOCHS,
                    'step': step + 1,
                    'loss': epoch_loss / (step + 1),
                    'data_len': num_batches,
                }
                # presumably overwrite=True refreshes the current status line
                # in place -- confirm against Logger.write
                log.write(ctx, overwrite=True)
            # Final write for the epoch, using the last step's context.
            log.write(ctx)
    finally:
        log.close()

    ###################
    #   Save Models   #
    ###################

    auto_encoder.save_states()
예제 #2
0
def test():
    """Evaluate the saved auto-encoder and visualise sample reconstructions.

    Loads the model states, computes the running mean MSE reconstruction
    loss over the whole Pokemon dataset (logged to ``test.log``), then
    passes the first ten dataset items through the model and hands the
    (input, output) pairs to ``visualize_input_output`` with ``save=True``.
    """
    # Local import so the block does not depend on how the file imports torch.
    import torch

    ###################
    #    Load Data    #
    ###################

    dataset = PokemonDataset()
    data_loader = DataLoader(dataset)
    # Number of batches; loop-invariant, so computed once.
    num_batches = len(data_loader)

    ###################
    #  Set Up Models  #
    ###################

    auto_encoder = models.AutoEncoder()
    auto_encoder.load_states()

    ###################
    #     Testing     #
    ###################

    auto_encoder.eval()
    log = Logger('test.log', TestingLogTemplate())

    total_loss = 0
    # try/finally guarantees the log is closed even if evaluation raises.
    try:
        # No gradients are needed for evaluation; skipping autograd avoids
        # building a graph for every batch.
        with torch.no_grad():
            for step, (x, _) in enumerate(data_loader):
                # Evaluate: reconstruction error against the input itself.
                y = auto_encoder(x)
                loss = F.mse_loss(y, x)

                # Display status: running mean loss over the batches so far.
                total_loss += loss.item()
                ctx = {
                    'step': step + 1,
                    'loss': total_loss / (step + 1),
                    'data_len': num_batches,
                }
                log.write(ctx, overwrite=True)
    finally:
        log.close()

    ###################
    #     Visuals     #
    ###################

    # Get inputs and outputs for the first ten dataset items.
    # NOTE(review): the original comment said "nine" while the slice takes
    # ten -- confirm which count visualize_input_output expects.
    samples = []
    with torch.no_grad():
        for x, _ in dataset[:10]:
            # Add a batch dimension for the model, drop it again for display.
            y = auto_encoder(x.unsqueeze(0))
            samples.append((x.squeeze(), y.squeeze()))

    # Display inputs and outputs
    visualize_input_output(samples, save=True)
예제 #3
0
def main(args):
    """Train and evaluate an auto-encoder, returning the best test loss.

    Each epoch runs one pass over the training loader (optimising MSE
    between the reconstruction and the input) followed by one pass over the
    test loader under ``torch.no_grad()``.  Scalars and, optionally, images
    are written to a ``SummaryWriter`` rooted at the run directory.  The
    scheduler is stepped with the epoch's average test loss, and the model
    weights are saved to ``<run_dir>/best`` whenever that loss is at least
    as good as the best seen so far.

    Args:
        args: configuration namespace.  Attributes read here: ``device``,
            ``seed``, ``dataset_name``, ``model_name``, ``run_id``,
            ``dataset_train_len``, ``dataset_test_len``, ``dataroot``,
            ``batchsize``, ``model_encoder``, ``model_decoder``,
            ``model_type``, ``load``, ``epochs``, ``demo``, ``optim_class``,
            ``display`` and ``checkpoint_freq``.

    Returns:
        The lowest average test loss over all epochs (``float('inf')`` if no
        epoch was run).
    """

    def get_lr(optimizer):
        """Return the learning rate of the optimizer's first param group."""
        for param_group in optimizer.param_groups:
            return param_group['lr']

    def log(phase):
        """Write the current step's loss (and periodically images).

        Reads ``loss``, ``i``, ``x``, ``x_``, ``z`` and ``global_step`` from
        the enclosing scope at call time, so it must be called after the
        forward pass of the current step.
        """
        writer.add_scalar(f'{phase}_loss', loss.item(), global_step)

        # Both None and 0 mean "no image logging".  Testing truthiness first
        # avoids the ZeroDivisionError that `i % 0` would raise.
        if args.display and i % args.display == 0:
            # Input and reconstruction of the first sample, joined on dim 2.
            recon = torch.cat((reverse_augment(x[0]), reverse_augment(x_[0])),
                              dim=2)
            writer.add_image(f'{phase}_recon', recon, global_step)
            view_in.render(recon)

            if args.model_type != 'fc':
                latent = make_grid(z[0].unsqueeze(1), 4, 4)
                writer.add_image(f'{phase}_latent', latent.squeeze(0),
                                 global_step)
                view_z.render(latent)

    def nop(x):
        """Identity transform used when no flattening is needed."""
        return x

    def flatten(x):
        """Flatten every dimension after the first (batch) dimension."""
        return x.flatten(start_dim=1)

    def reverse_flatten(x):
        """Reshape a flattened sample back to (1, 28, 28) for display."""
        return x.reshape(1, 28, 28)

    torch.cuda.set_device(args.device)

    # --- reproducibility ---
    if args.seed is not None:
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(args.seed)

    # --- variables ---
    # Start from +inf so the first epoch always counts as an improvement,
    # whatever the scale of the loss.
    best_loss = float('inf')
    run_dir = f'data/models/autoencoders/{args.dataset_name}/{args.model_name}/run_{args.run_id}'
    writer = SummaryWriter(log_dir=run_dir)
    global_step = 0

    # try/finally so the writer is closed on success and on failure.
    try:
        # --- data ---
        datapack = package.datasets[args.dataset_name]
        train, test = datapack.make(args.dataset_train_len,
                                    args.dataset_test_len,
                                    data_root=args.dataroot)
        train_l = DataLoader(train,
                             batch_size=args.batchsize,
                             shuffle=True,
                             drop_last=True,
                             pin_memory=True)
        test_l = DataLoader(test,
                            batch_size=args.batchsize,
                            shuffle=True,
                            drop_last=True,
                            pin_memory=True)

        # --- model ---
        encoder, shape = make_layers(cfg=args.model_encoder,
                                     type=args.model_type,
                                     input_shape=datapack.shape)
        # The decoder input is the last shape reported for the encoder.
        decoder, shape = make_layers(cfg=args.model_decoder,
                                     type=args.model_type,
                                     input_shape=shape[-1])

        auto_encoder = models.AutoEncoder(encoder, decoder).to(args.device)
        print(auto_encoder)
        # Fully-connected models take flattened input; others take it as is.
        augment = flatten if args.model_type == 'fc' else nop
        reverse_augment = reverse_flatten if args.model_type == 'fc' else nop

        if args.load is not None:
            auto_encoder.load_state_dict(torch.load(args.load))

        # --- optimizer ---
        optim, scheduler = config.get_optim(args, auto_encoder.parameters())

        # --- loss function ---
        criterion = nn.MSELoss()

        for epoch in range(1, args.epochs + 1):
            # --- training ---
            # NOTE(review): the model is never switched between train() and
            # eval() mode here -- confirm that is intended for this model.
            batch = tqdm(train_l, total=len(train) // args.batchsize)
            for i, (x, _) in enumerate(batch):
                x = augment(x).to(args.device)

                optim.zero_grad()
                z, x_ = auto_encoder(x)
                loss = criterion(x_, x)
                # In demo mode the weights are left untouched.
                if not args.demo:
                    loss.backward()
                    optim.step()

                batch.set_description(
                    f'Epoch: {epoch} {args.optim_class} LR: {get_lr(optim)} Train Loss: {loss.item()}'
                )

                log('train')

                if i % args.checkpoint_freq == 0 and args.demo == 0:
                    torch.save(auto_encoder.state_dict(),
                               run_dir + '/checkpoint')

                global_step += 1

            # --- test ---
            with torch.no_grad():
                ll = 0.0
                batch = tqdm(test_l, total=len(test) // args.batchsize)
                for i, (images, _) in enumerate(batch):
                    x = augment(images).to(args.device)

                    z, x_ = auto_encoder(x)
                    loss = criterion(x_, x)

                    # Running mean of the test loss over this epoch.
                    ll += loss.item()
                    ave_loss = ll / (i + 1)
                    batch.set_description(
                        f'Epoch: {epoch} Test Loss: {ave_loss}')

                    log('test')

                    global_step += 1

            # --- check improvement ---
            scheduler.step(ave_loss)

            # Decide before updating best_loss so the save condition is
            # explicit.
            improved = ave_loss <= best_loss
            if improved:
                best_loss = ave_loss
            print(
                f'{Fore.CYAN}ave loss: {ave_loss} {Fore.LIGHTBLUE_EX}best loss: {best_loss} {Style.RESET_ALL}'
            )

            # --- save if model improved ---
            if improved and not args.demo:
                torch.save(auto_encoder.state_dict(), run_dir + '/best')
    finally:
        writer.close()

    return best_loss
예제 #4
0
def main():
    """Train, fine-tune or evaluate the auto-encoder on CURE-TSR data.

    Parses the command line into the module-level ``args``, builds the train
    and test loaders, then takes one of two paths:

    * ``args.evaluate``: load the best auto-encoder checkpoint and the CNN
      checkpoint, run ``evaluate`` on the test loader and return.
    * otherwise: train from ``args.start_epoch`` to ``args.epochs``
      (optionally resuming from the best checkpoint when ``args.finetune``
      is set), saving a checkpoint after every epoch, then plot the
      per-epoch losses and print a summary.
    """
    global args
    args = parser.parse_args()

    def build_loader(root):
        """Build the dataset under *root* and wrap it in an unshuffled loader."""
        dataset = utils.CURETSRDataset(
            root,
            transforms.Compose([
                transforms.Resize([28, 28]),
                transforms.ToTensor(), utils.l2normalize, utils.standardization
            ]))
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.workers,
                                           pin_memory=True)

    train_loader = build_loader(
        os.path.join(args.data, 'RealChallengeFree/train'))
    test_loader = build_loader(
        os.path.join(args.data, 'RealChallengeFree/Test'))

    model = models.AutoEncoder()
    model = torch.nn.DataParallel(model).cuda()
    print("=> creating model %s " % model.__class__.__name__)
    criterion = nn.MSELoss().cuda()

    savedir = 'AutoEncoder'
    checkpointdir = os.path.join('./checkpoints', savedir)
    os.makedirs(checkpointdir, exist_ok=True)
    print('log directory: %s' % os.path.join('./logs', savedir))
    print('checkpoints directory: %s' % checkpointdir)
    logger = Logger(os.path.join('./logs/', savedir))

    if args.evaluate:
        # Evaluation only: restore both networks and stop after evaluating.
        print("=> loading checkpoint ")
        checkpoint = torch.load(
            os.path.join(checkpointdir, 'model_best.pth.tar'))
        model.load_state_dict(checkpoint['AE_state_dict'], strict=False)
        modelCNN = models.Net()
        modelCNN = torch.nn.DataParallel(modelCNN).cuda()
        checkpoint2 = torch.load('./checkpoints/CNN_iter/model_best.pth.tar')
        modelCNN.load_state_dict(checkpoint2['state_dict'], strict=False)
        evaluate(test_loader, model, modelCNN, criterion)
        return

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)
    cudnn.benchmark = True

    timestart = time.time()

    if args.finetune:
        # Resume both the model weights and the optimizer state.
        print("=> loading checkpoint ")
        checkpoint = torch.load(
            os.path.join(checkpointdir, 'model_best.pth.tar'))
        model.load_state_dict(checkpoint['AE_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer'])

    best_loss = 10e10
    # Initialised up front: if no epoch runs, or no epoch ever beats
    # best_loss (e.g. a NaN loss), the summary below would otherwise fail on
    # an unbound name.
    best_epoch = None
    loss_epochs = []

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        print('\n*** Start Training *** \n')
        loss_train = train(train_loader, test_loader, model, criterion,
                           optimizer, epoch)
        print(loss_train)
        loss_epochs.append(loss_train)
        # "Best" is judged on the loss returned by train().
        is_best = loss_train < best_loss
        print(best_loss)
        best_loss = min(loss_train, best_loss)
        info = {
            'Loss': loss_train
        }
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch + 1)
        if is_best:
            best_epoch = epoch + 1
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'AE_state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, is_best, checkpointdir)

    generate_plots(range(args.start_epoch, args.epochs), loss_epochs)
    print('Best epoch: ', best_epoch)
    print('Total processing time: %.4f' % (time.time() - timestart))
    print('Best loss:', best_loss)
예제 #5
0
    sample_fn = lambda x: torch.round(torch.sigmoid(x))
    x_dim = 784
# Dataset-specific loss and sampling function.  For 'fmnist' and 'svhn' the
# reconstruction loss is element-wise MSE (reduction='none') and samples are
# the sigmoid of the decoder output.
if args.dataset in ['fmnist', 'svhn']:
    loss_fn = torch.nn.MSELoss(reduction='none')
    sample_fn = lambda x: torch.sigmoid(x)
# Dataset-specific dimensions: x_dim is the input size, z_dim the latent
# size and z0_dim the size of the generator's input.
# presumably 784 = 28*28 flattened grayscale images -- TODO confirm
if args.dataset in ['fmnist', 'mnist']:
    x_dim = 784
    z_dim = 64
    z0_dim = 64
# presumably 3072 = 32*32*3 flattened RGB images -- TODO confirm
if args.dataset == 'svhn':
    x_dim = 3072
    z_dim = z0_dim = 50

# Build the auto-encoder plus the generator/discriminator pair that works on
# its latent space.  "sup" uses the same AutoEncoder class as "ae" but with
# sup=True.
# NOTE(review): if args.ae_mode is none of "ae", "sup" or "conv", nothing
# below assigns `ae`, `generator` or `discriminator`, and the optimizer
# setup would fail unless they are defined earlier in the file.
if args.ae_mode == "ae" or args.ae_mode == "sup":
    ae = models.AutoEncoder(x_dim,
                            z_dim,
                            n_units=[500, 500],
                            sup=(args.ae_mode == "sup")).to(device)
    generator = models.Encoder(z0_dim, z_dim, [500, 500]).to(device)
    discriminator = models.Discriminator(z_dim, [20, 20]).to(device)
# The convolutional variant takes its latent size from the model (ae.z_dim)
# rather than from z_dim above.
# NOTE(review): in_channels=1 and image_size=(28, 28) are hard-coded, which
# looks incompatible with 'svhn' -- confirm.
if args.ae_mode == "conv":
    ae = models.ConvAutoEncoder(in_channels=1,
                                image_size=(28, 28),
                                activation=None).to(device)
    generator = models.ConvGenerator(z0_dim, ae.z_dim).to(device)
    discriminator = models.ConvDiscriminator(ae.z_dim).to(device)

# One Adam optimizer per network; the generator's learning rate (1e-2) is
# ten times that of the other two (1e-3).
g_optimizer = optim.Adam(generator.parameters(), lr=1e-2)
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-3)
ed_optimizer = optim.Adam(ae.parameters(), lr=1e-3)

if args.train_ae: