Exemple #1
0
def main(dataname):
    """Encode a demonstration dataset with the saved CAE model.

    Loads the weights stored at ``models/CAE_model`` onto the CPU, wraps
    ``MotionData(dataname)`` in a shuffled ``DataLoader`` and passes every
    batch through ``model.encoder``.

    Args:
        dataname: Forwarded unchanged to ``MotionData`` (presumably the path
            of a pickled demonstration file -- confirm against
            ``MotionData``).

    Returns:
        numpy.ndarray: The encoder output of the *last* batch processed,
        converted via ``tolist()``. Earlier batches are only reported through
        ``print``; with the batch size used here a dataset of up to 49950
        samples is a single batch. An empty array is returned when the loader
        yields no batch at all.
    """
    model = CAE()
    model_dict = torch.load('models/CAE_model', map_location='cpu')
    model.load_state_dict(model_dict)
    # eval() has to be *called*; a bare ``model.eval`` is a no-op attribute
    # access and leaves the module in training mode.
    model.eval()

    EPOCH = 1
    BATCH_SIZE_TRAIN = 49950

    train_data = MotionData(dataname)
    train_set = DataLoader(dataset=train_data,
                           batch_size=BATCH_SIZE_TRAIN,
                           shuffle=True)

    # Pre-bind so an empty loader yields an empty array instead of NameError.
    z = []
    # Pure inference: no gradients are needed for encoding.
    with torch.no_grad():
        for _ in range(EPOCH):
            for x in train_set:
                z = model.encoder(x).tolist()
                print(len(z))
    return np.asarray(z)
Exemple #2
0
class Model(object):
    """Inference wrapper that exposes the decoder of a trained ``CAE``.

    The weights are read from ``models/CAE_model`` and mapped to the CPU.
    """

    def __init__(self):
        self.model = CAE()
        model_dict = torch.load('models/CAE_model', map_location='cpu')
        self.model.load_state_dict(model_dict)
        # eval() must be called; the bare attribute access ``self.model.eval``
        # does nothing and would leave the network in training mode.
        self.model.eval()

    @staticmethod
    def _batch_of_one(values):
        """Return ``values`` as a float32 tensor with a leading batch axis of 1."""
        # Going through one ndarray avoids the slow "tensor from a list of
        # ndarrays" construction path; torch.tensor always copies the data.
        tensor = torch.tensor(np.asarray(values), dtype=torch.float32)
        return tensor.unsqueeze(0).contiguous()

    def decoder(self, img, s, z):
        """Decode one (image, state, latent) sample into an action.

        Args:
            img: Array-like image supporting ``/`` and ``-`` with three axes;
                it is rescaled by ``img / 128.0 - 1.0`` and its last axis is
                moved to the front (presumably HWC pixel data in [0, 255]
                turned into CHW in roughly [-1, 1) -- confirm with the
                training pipeline).
            s: State values for a single sample.
            z: Latent values for a single sample.

        Returns:
            list: Elements of the first row of the decoder output, as NumPy
            scalars.
        """
        img = img / 128.0 - 1.0
        img = np.transpose(img, (2, 0, 1))
        context = (
            self._batch_of_one(img),
            self._batch_of_one(s),
            self._batch_of_one(z),
        )
        # Inference only, so skip building the autograd graph.
        with torch.no_grad():
            a_tensor = self.model.decoder(context)
        a_numpy = a_tensor.detach().numpy()[0]
        return list(a_numpy)
Exemple #3
0
class Model(object):
    """Inference wrapper around the encoder and decoder of a trained ``CAE``.

    The weights are read from ``models/CAE_best_model`` and mapped to the CPU.
    """

    def __init__(self):
        self.model = CAE()
        model_dict = torch.load('models/CAE_best_model', map_location='cpu')
        self.model.load_state_dict(model_dict)
        # eval() must be called; the bare attribute access ``self.model.eval``
        # does nothing and would leave the network in training mode.
        self.model.eval()

    def decoder(self, z, s):
        """Decode latent ``z`` given state ``s``.

        Args:
            z: 2-D array-like of latent values (indexed as ``z[0][0]``).
            s: 2-D array-like of states with the same number of rows as
                ``z``; it is concatenated to ``z`` along axis 1.

        Returns:
            list: ``[0.0] * 6`` (a flat list) when ``abs(z[0][0]) < 0.01``;
            otherwise the decoder output converted with ``tolist()`` (a
            nested list). NOTE(review): the two branches return different
            shapes, and the threshold / length 6 look like an input dead-zone
            and the action dimension -- confirm with callers.
        """
        if abs(z[0][0]) < 0.01:
            return [0.0] * 6
        z = np.asarray(z)
        decoder_input = np.concatenate((z, s), axis=1)
        z_tensor = torch.as_tensor(decoder_input, dtype=torch.float32)
        # Inference only, so skip building the autograd graph.
        with torch.no_grad():
            a_tensor = self.model.decoder(z_tensor)
        return a_tensor.tolist()

    def encoder(self, a, s):
        """Encode action ``a`` together with state ``s``.

        Args:
            a: 2-D array-like of actions.
            s: 2-D array-like of states with the same number of rows as
                ``a``; it is concatenated to ``a`` along axis 1.

        Returns:
            list: The encoder output converted with ``tolist()``.
        """
        x = np.concatenate((a, s), axis=1)
        x_tensor = torch.as_tensor(x, dtype=torch.float32)
        with torch.no_grad():
            z = self.model.encoder(x_tensor)
        return z.tolist()
Exemple #4
0
def _checkpoint_state(epoch, model, optimizer, logger, opts, val_acc,
                      best_val_acc):
    """Collect everything that is written to a checkpoint file into a dict."""
    return {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'n_iter': logger.n_iter,
        'opts': opts,
        'val_acc': val_acc,
        'best_val_acc': best_val_acc
    }


def _save_snapshot(save_state, model, logger, save_path, suffix, epoch):
    """Write ``model_<suffix>.net`` and log the matching prototype grid."""
    torch.save(save_state, os.path.join(save_path, 'model_%s.net' % suffix))
    prototypes = model.save_prototypes(save_path,
                                       'prototypes_%s.png' % suffix)
    grid = torchvision.utils.make_grid(prototypes, nrow=10, pad_value=1.0)
    logger.writer.add_image('Prototypes (%s)' % suffix, grid, epoch)


def train(opts):
    """Train a ``CAE`` classifier/auto-encoder and checkpoint it every epoch.

    Args:
        opts: Options object. The fields read here are ``arch``, ``mode``,
            ``data_dir``, ``bsize``, ``nworkers``, ``sigma``, ``alpha``,
            ``n_prototypes``, ``decoder_arch``, ``optim``, ``lr``, ``wd``,
            ``save_path``, ``resume``, ``start_epoch``, ``log_iter``,
            ``log_dir``, ``epochs`` and ``max_norm``. When resuming, ``opts``
            is replaced by the options stored in the checkpoint and its
            ``start_epoch`` is set to the saved epoch plus one.

    Raises:
        NotImplementedError: If ``opts.arch``, ``opts.mode`` or
            ``opts.optim`` holds an unsupported value.

    Side effects:
        After each epoch writes ``model_latest.net`` and
        ``prototypes_latest.png`` under ``opts.save_path``; additionally
        writes ``model_best.net`` / ``prototypes_best.png`` whenever the
        validation accuracy is at least the best seen so far. Images and
        metrics are sent to the ``TensorboardLogger``.
    """
    # ``use_cuda`` is a module-level flag defined outside this function.
    device = torch.device("cuda" if use_cuda else "cpu")

    if opts.arch == 'small':
        channels = [32, 32, 32, 10]
    elif opts.arch == 'large':
        channels = [256, 128, 64, 32]
    else:
        raise NotImplementedError('Unknown model architecture')

    if opts.mode == 'train_mnist':
        train_loader, valid_loader = get_mnist_loaders(opts.data_dir,
                                                       opts.bsize,
                                                       opts.nworkers,
                                                       opts.sigma, opts.alpha)
        model = CAE(1, 10, 28, opts.n_prototypes, opts.decoder_arch, channels)
    elif opts.mode == 'train_cifar':
        train_loader, valid_loader = get_cifar_loaders(opts.data_dir,
                                                       opts.bsize,
                                                       opts.nworkers,
                                                       opts.sigma, opts.alpha)
        model = CAE(3, 10, 32, opts.n_prototypes, opts.decoder_arch, channels)
    elif opts.mode == 'train_fmnist':
        train_loader, valid_loader = get_fmnist_loaders(
            opts.data_dir, opts.bsize, opts.nworkers, opts.sigma, opts.alpha)
        model = CAE(1, 10, 28, opts.n_prototypes, opts.decoder_arch, channels)
    else:
        raise NotImplementedError('Unknown train mode')

    if opts.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opts.lr,
                                     weight_decay=opts.wd)
    else:
        raise NotImplementedError("Unknown optim type")
    criterion = nn.CrossEntropyLoss()

    start_n_iter = 0
    # for choosing the best model
    best_val_acc = 0.0

    resumed_state = None
    model_path = os.path.join(opts.save_path, 'model_latest.net')
    if opts.resume and os.path.exists(model_path):
        # restoring training from save_state
        print('====> Resuming training from previous checkpoint')
        resumed_state = torch.load(model_path, map_location='cpu')
        model.load_state_dict(resumed_state['state_dict'])
        start_n_iter = resumed_state['n_iter']
        best_val_acc = resumed_state['best_val_acc']
        opts = resumed_state['opts']
        opts.start_epoch = resumed_state['epoch'] + 1

    model = model.to(device)

    # The optimizer state is saved in every checkpoint, so restore it too;
    # otherwise Adam's moment estimates restart from zero on every resume.
    # This happens after the model is moved so that load_state_dict places
    # the state tensors on the same device as the parameters.
    if resumed_state is not None and 'optimizer' in resumed_state:
        optimizer.load_state_dict(resumed_state['optimizer'])

    # for logging
    logger = TensorboardLogger(opts.start_epoch, opts.log_iter, opts.log_dir)
    logger.set(['acc', 'loss', 'loss_class', 'loss_ae', 'loss_r1', 'loss_r2'])
    logger.n_iter = start_n_iter

    for epoch in range(opts.start_epoch, opts.epochs):
        model.train()
        logger.step()
        # Ten random validation images, later used for the AE sample grid.
        valid_sample = torch.stack([
            valid_loader.dataset[i][0]
            for i in random.sample(range(len(valid_loader.dataset)), 10)
        ]).to(device)

        for batch_idx, (data, target) in enumerate(train_loader):
            acc, loss, class_error, ae_error, error_1, error_2 = run_iter(
                opts, data, target, model, criterion, device)

            # optimizer step
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), opts.max_norm)
            optimizer.step()

            logger.update(acc, loss, class_error, ae_error, error_1, error_2)

        (val_loss, val_acc, val_class_error, val_ae_error, val_error_1,
         val_error_2, time_taken) = evaluate(opts, model, valid_loader,
                                             criterion, device)
        # log the validation losses
        logger.log_valid(time_taken, val_acc, val_loss, val_class_error,
                         val_ae_error, val_error_1, val_error_2)
        print('')

        # Save the model to disk
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            save_state = _checkpoint_state(epoch, model, optimizer, logger,
                                           opts, val_acc, best_val_acc)
            _save_snapshot(save_state, model, logger, opts.save_path, 'best',
                           epoch)

        save_state = _checkpoint_state(epoch, model, optimizer, logger, opts,
                                       val_acc, best_val_acc)
        _save_snapshot(save_state, model, logger, opts.save_path, 'latest',
                       epoch)
        ae_samples = model.get_decoded_pairs_grid(valid_sample)
        logger.writer.add_image('AE_samples_latest', ae_samples, epoch)