import numpy as np
import torch
from torch.utils.data import DataLoader

# CAE and MotionData are assumed to be defined/imported elsewhere in this repo.


def main(dataname):
    # Load the trained conditional autoencoder and switch it to evaluation mode.
    model = CAE()
    name = "CAE"
    model_dict = torch.load('models/CAE_model', map_location='cpu')
    model.load_state_dict(model_dict)
    model.eval()

    EPOCH = 1
    BATCH_SIZE_TRAIN = 49950
    # dataname = "demonstrations/demo_00_02.pkl"
    save_model_path = "models/" + name + "_model"
    best_model_path = "models/" + name + "_best_model"

    train_data = MotionData(dataname)
    train_set = DataLoader(dataset=train_data, batch_size=BATCH_SIZE_TRAIN, shuffle=True)

    for epoch in range(EPOCH):
        epoch_loss = 0.0
        for batch, x in enumerate(train_set):
            # loss = model(x)
            # Encode the whole batch of demonstrations into the latent space.
            z = model.encoder(x).tolist()
            print(len(z))
            # print(epoch, loss.item())

    data = np.asarray(z)
    return data
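# Usage sketch (assumption): running the latent extraction above on one recorded
# demonstration file. The pickle path is taken from the commented-out example in
# main(); substitute whichever demonstration file exists in your repo.
latents = main("demonstrations/demo_00_02.pkl")
print(latents.shape)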
class Model(object):
    def __init__(self):
        self.model = CAE()
        model_dict = torch.load('models/CAE_model', map_location='cpu')
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def decoder(self, img, s, z):
        # Scale pixel values to roughly [-1, 1] and convert HWC -> CHW.
        img = img / 128.0 - 1.0
        img = np.transpose(img, (2, 0, 1))
        # Add a batch dimension to each input before decoding.
        img = torch.FloatTensor([img])
        s = torch.FloatTensor([s])
        z = torch.FloatTensor([z])
        context = (img, s, z)
        a_tensor = self.model.decoder(context)
        a_numpy = a_tensor.detach().numpy()[0]
        return list(a_numpy)
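# Usage sketch (assumption): calling the image-conditioned decoder above with
# dummy inputs. The shapes are illustrative placeholders (an HxWx3 camera frame,
# a state vector s, and a latent action z); the real dimensions depend on how
# CAE is defined elsewhere in this repo.
model = Model()
img = np.zeros((64, 64, 3))   # placeholder camera frame
s = [0.0] * 7                 # placeholder robot state
z = [0.0, 0.0]                # placeholder latent action
action = model.decoder(img, s, z)
print(action)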
class Model(object):
    def __init__(self):
        self.model = CAE()
        model_dict = torch.load('models/CAE_best_model', map_location='cpu')
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def decoder(self, z, s):
        # Treat a near-zero latent input as "no motion".
        if abs(z[0][0]) < 0.01:
            return [0.0] * 6
        # z = np.asarray([z])
        z = np.asarray(z)
        # print(s.shape)
        z_tensor = torch.FloatTensor(np.concatenate((z, s), axis=1))
        # print(z_tensor.shape)
        a_tensor = self.model.decoder(z_tensor)
        return a_tensor.tolist()

    def encoder(self, a, s):
        x = np.concatenate((a, s), axis=1)
        x_tensor = torch.FloatTensor(x)
        z = self.model.encoder(x_tensor)
        return z.tolist()
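# Usage sketch (assumption): round-tripping an action through the latent model
# above. The 6-D action matches the [0.0] * 6 fallback in decoder(); the 7-D
# state is an illustrative placeholder fixed by the CAE definition elsewhere.
model = Model()
a = np.zeros((1, 6))          # placeholder action
s = np.zeros((1, 7))          # placeholder robot state
z = model.encoder(a, s)       # (action, state) -> latent action
a_hat = model.decoder(z, s)   # (latent action, state) -> reconstructed action
print(z, a_hat)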
import os
import random

import torch.nn as nn
import torchvision

# get_mnist_loaders, get_cifar_loaders, get_fmnist_loaders, run_iter, evaluate,
# TensorboardLogger, and use_cuda are assumed to be defined in this repo's other
# modules.
def train(opts):
    device = torch.device("cuda" if use_cuda else "cpu")

    if opts.arch == 'small':
        channels = [32, 32, 32, 10]
    elif opts.arch == 'large':
        channels = [256, 128, 64, 32]
    else:
        raise NotImplementedError('Unknown model architecture')

    if opts.mode == 'train_mnist':
        train_loader, valid_loader = get_mnist_loaders(
            opts.data_dir, opts.bsize, opts.nworkers, opts.sigma, opts.alpha)
        model = CAE(1, 10, 28, opts.n_prototypes, opts.decoder_arch, channels)
    elif opts.mode == 'train_cifar':
        train_loader, valid_loader = get_cifar_loaders(
            opts.data_dir, opts.bsize, opts.nworkers, opts.sigma, opts.alpha)
        model = CAE(3, 10, 32, opts.n_prototypes, opts.decoder_arch, channels)
    elif opts.mode == 'train_fmnist':
        train_loader, valid_loader = get_fmnist_loaders(
            opts.data_dir, opts.bsize, opts.nworkers, opts.sigma, opts.alpha)
        model = CAE(1, 10, 28, opts.n_prototypes, opts.decoder_arch, channels)
    else:
        raise NotImplementedError('Unknown train mode')

    if opts.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.wd)
    else:
        raise NotImplementedError("Unknown optim type")
    criterion = nn.CrossEntropyLoss()

    start_n_iter = 0
    # for choosing the best model
    best_val_acc = 0.0

    model_path = os.path.join(opts.save_path, 'model_latest.net')
    if opts.resume and os.path.exists(model_path):
        # restoring training from save_state
        print('====> Resuming training from previous checkpoint')
        save_state = torch.load(model_path, map_location='cpu')
        model.load_state_dict(save_state['state_dict'])
        start_n_iter = save_state['n_iter']
        best_val_acc = save_state['best_val_acc']
        opts = save_state['opts']
        opts.start_epoch = save_state['epoch'] + 1

    model = model.to(device)

    # for logging
    logger = TensorboardLogger(opts.start_epoch, opts.log_iter, opts.log_dir)
    logger.set(['acc', 'loss', 'loss_class', 'loss_ae', 'loss_r1', 'loss_r2'])
    logger.n_iter = start_n_iter

    for epoch in range(opts.start_epoch, opts.epochs):
        model.train()
        logger.step()

        valid_sample = torch.stack([
            valid_loader.dataset[i][0]
            for i in random.sample(range(len(valid_loader.dataset)), 10)
        ]).to(device)

        for batch_idx, (data, target) in enumerate(train_loader):
            acc, loss, class_error, ae_error, error_1, error_2 = run_iter(
                opts, data, target, model, criterion, device)

            # optimizer step
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), opts.max_norm)
            optimizer.step()

            logger.update(acc, loss, class_error, ae_error, error_1, error_2)

        val_loss, val_acc, val_class_error, val_ae_error, val_error_1, val_error_2, time_taken = evaluate(
            opts, model, valid_loader, criterion, device)
        # log the validation losses
        logger.log_valid(time_taken, val_acc, val_loss, val_class_error,
                         val_ae_error, val_error_1, val_error_2)
        print('')

        # Save the model to disk
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            save_state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'n_iter': logger.n_iter,
                'opts': opts,
                'val_acc': val_acc,
                'best_val_acc': best_val_acc
            }
            model_path = os.path.join(opts.save_path, 'model_best.net')
            torch.save(save_state, model_path)
            prototypes = model.save_prototypes(opts.save_path, 'prototypes_best.png')
            x = torchvision.utils.make_grid(prototypes, nrow=10, pad_value=1.0)
            logger.writer.add_image('Prototypes (best)', x, epoch)

        save_state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'n_iter': logger.n_iter,
            'opts': opts,
            'val_acc': val_acc,
            'best_val_acc': best_val_acc
        }
        model_path = os.path.join(opts.save_path, 'model_latest.net')
        torch.save(save_state, model_path)
        prototypes = model.save_prototypes(opts.save_path, 'prototypes_latest.png')
        x = torchvision.utils.make_grid(prototypes, nrow=10, pad_value=1.0)
        logger.writer.add_image('Prototypes (latest)', x, epoch)

        ae_samples = model.get_decoded_pairs_grid(valid_sample)
        logger.writer.add_image('AE_samples_latest', ae_samples, epoch)
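# Usage sketch (assumption): building an opts namespace containing the fields
# that train() reads above and launching a run. The field names come from the
# function body; the concrete values are illustrative defaults, not the repo's,
# and in the actual project they presumably come from an argparse parser.
if __name__ == "__main__":
    from argparse import Namespace
    opts = Namespace(
        mode='train_mnist', arch='small', decoder_arch='deconv',
        n_prototypes=15, data_dir='data/', save_path='runs/mnist/',
        log_dir='runs/mnist/logs/', log_iter=100, bsize=128, nworkers=4,
        sigma=4, alpha=20, optim='adam', lr=1e-3, wd=0.0, max_norm=5.0,
        start_epoch=0, epochs=50, resume=False)
    train(opts)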