class Agent():
    """REINFORCE-style policy-gradient agent for a 4-dim-state, 2-action task.

    Holds the policy network, an episode buffer (``Memory``) of rewards and
    log-probabilities, and performs one policy-gradient step per ``update()``.
    """

    def __init__(self, test=False):
        # Run on GPU when one is available, otherwise fall back to CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = MLP(state_dim=4, action_num=2, hidden_dim=256).to(self.device)
        if test:
            self.load('./pg_best.cpt')
        # Discount factor used when accumulating returns.
        self.gamma = 0.99
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-3)
        # Rollout buffer: rewards, actions, states and log-probs of one episode.
        self.memory = Memory()
        self.tensorboard = TensorboardLogger('./')

    def save(self, save_path):
        """Write the policy network's weights to ``save_path``."""
        print('save model to', save_path)
        torch.save(self.model.state_dict(), save_path)

    def load(self, load_path):
        """Restore the policy network's weights from ``load_path``."""
        print('load model from', load_path)
        self.model.load_state_dict(torch.load(load_path))

    def act(self, x, test=False):
        """Sample an action for state ``x`` (a numpy array).

        In training mode the action's log-probability is appended to the
        episode buffer so ``update()`` can build the policy-gradient loss.
        In test mode the model is switched to eval and no gradients are kept.
        """
        if test:
            self.model.eval()
            state = torch.from_numpy(x).unsqueeze(0).float().to(self.device)
            with torch.no_grad():
                probs = self.model(state)
            # Still a stochastic draw in test mode (argmax intentionally not used).
            chosen = torch.distributions.Categorical(probs).sample()
            return chosen.item()
        # Training path: cast the numpy state to a batched float tensor.
        state = torch.from_numpy(x).unsqueeze(0).float().to(self.device)
        probs = self.model(state)
        dist = torch.distributions.Categorical(probs)
        chosen = dist.sample()
        # Remember the log-probability for the REINFORCE update.
        self.memory.logprobs.append(dist.log_prob(chosen))
        return chosen.item()

    def collect_data(self, state, action, reward):
        """Record one transition of the current episode."""
        self.memory.actions.append(action)
        self.memory.rewards.append(torch.tensor(reward))
        self.memory.states.append(state)

    def clear_data(self):
        """Drop everything stored for the finished episode."""
        self.memory.clear_memory()

    def update(self):
        """Run one policy-gradient step over the episode in ``self.memory``."""
        # Discounted returns, accumulated back-to-front then flipped.
        returns = []
        running = 0
        for r in reversed(self.memory.rewards):
            running = running * self.gamma + r
            returns.append(running)
        returns.reverse()
        # Standardize the returns (variance reduction); eps avoids divide-by-zero.
        returns = torch.Tensor(returns).to(self.device)
        returns = (returns - returns.mean()) / (returns.std() + np.finfo(np.float32).eps)
        # REINFORCE objective: sum over t of -log pi(a_t|s_t) * G_t.
        terms = [-lp * adv for lp, adv in zip(self.memory.logprobs, returns)]
        self.optimizer.zero_grad()
        loss = torch.cat(terms).sum()
        loss.backward()
        self.optimizer.step()
        # Log the scalar loss to tensorboard.
        self.tensorboard.scalar_summary("loss", loss.item())
        self.tensorboard.update()
def train(opts):
    """Train a prototype CAE classifier according to ``opts``.

    Builds the model/dataloaders for the selected mode, optionally resumes
    from the latest checkpoint, then runs the epoch loop: optimize, validate,
    and write best/latest checkpoints plus prototype images to tensorboard.
    """
    # NOTE(review): ``use_cuda`` is not defined in this function — presumably
    # a module-level global; confirm it is set before train() is called.
    device = torch.device("cuda" if use_cuda else "cpu")

    # Decoder channel widths per architecture choice.
    if opts.arch == 'small':
        channels = [32, 32, 32, 10]
    elif opts.arch == 'large':
        channels = [256, 128, 64, 32]
    else:
        raise NotImplementedError('Unknown model architecture')

    # Dataset + model: 1-channel 28x28 for (F)MNIST, 3-channel 32x32 for CIFAR.
    if opts.mode == 'train_mnist':
        train_loader, valid_loader = get_mnist_loaders(
            opts.data_dir, opts.bsize, opts.nworkers, opts.sigma, opts.alpha)
        model = CAE(1, 10, 28, opts.n_prototypes, opts.decoder_arch, channels)
    elif opts.mode == 'train_cifar':
        train_loader, valid_loader = get_cifar_loaders(
            opts.data_dir, opts.bsize, opts.nworkers, opts.sigma, opts.alpha)
        model = CAE(3, 10, 32, opts.n_prototypes, opts.decoder_arch, channels)
    elif opts.mode == 'train_fmnist':
        train_loader, valid_loader = get_fmnist_loaders(
            opts.data_dir, opts.bsize, opts.nworkers, opts.sigma, opts.alpha)
        model = CAE(1, 10, 28, opts.n_prototypes, opts.decoder_arch, channels)
    else:
        raise NotImplementedError('Unknown train mode')

    if opts.optim == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(), lr=opts.lr, weight_decay=opts.wd)
    else:
        raise NotImplementedError("Unknown optim type")
    criterion = nn.CrossEntropyLoss()

    start_n_iter = 0
    # Tracks the best validation accuracy seen so far (for model_best.net).
    best_val_acc = 0.0

    model_path = os.path.join(opts.save_path, 'model_latest.net')
    if opts.resume and os.path.exists(model_path):
        # Restore training state from the latest checkpoint. Note this also
        # replaces ``opts`` wholesale with the saved options object.
        print('====> Resuming training from previous checkpoint')
        save_state = torch.load(model_path, map_location='cpu')
        model.load_state_dict(save_state['state_dict'])
        start_n_iter = save_state['n_iter']
        best_val_acc = save_state['best_val_acc']
        opts = save_state['opts']
        opts.start_epoch = save_state['epoch'] + 1

    model = model.to(device)

    # Tensorboard logging of train/valid metrics.
    logger = TensorboardLogger(opts.start_epoch, opts.log_iter, opts.log_dir)
    logger.set(['acc', 'loss', 'loss_class', 'loss_ae', 'loss_r1', 'loss_r2'])
    logger.n_iter = start_n_iter

    for epoch in range(opts.start_epoch, opts.epochs):
        model.train()
        logger.step()

        # Fresh random sample of 10 validation images each epoch, used below
        # to visualize autoencoder input/output pairs.
        valid_sample = torch.stack([
            valid_loader.dataset[i][0]
            for i in random.sample(range(len(valid_loader.dataset)), 10)
        ]).to(device)

        for batch_idx, (data, target) in enumerate(train_loader):
            acc, loss, class_error, ae_error, error_1, error_2 = run_iter(
                opts, data, target, model, criterion, device)
            # Gradient step with norm clipping.
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), opts.max_norm)
            optimizer.step()
            logger.update(acc, loss, class_error, ae_error, error_1, error_2)

        val_loss, val_acc, val_class_error, val_ae_error, val_error_1, val_error_2, time_taken = evaluate(
            opts, model, valid_loader, criterion, device)
        logger.log_valid(time_taken, val_acc, val_loss, val_class_error,
                         val_ae_error, val_error_1, val_error_2)
        print('')

        # Checkpoint the best-so-far model (ties also refresh the checkpoint).
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            save_state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'n_iter': logger.n_iter,
                'opts': opts,
                'val_acc': val_acc,
                'best_val_acc': best_val_acc
            }
            model_path = os.path.join(opts.save_path, 'model_best.net')
            torch.save(save_state, model_path)
            prototypes = model.save_prototypes(opts.save_path, 'prototypes_best.png')
            x = torchvision.utils.make_grid(prototypes, nrow=10, pad_value=1.0)
            logger.writer.add_image('Prototypes (best)', x, epoch)

        # Always checkpoint the latest model at the end of the epoch.
        save_state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'n_iter': logger.n_iter,
            'opts': opts,
            'val_acc': val_acc,
            'best_val_acc': best_val_acc
        }
        model_path = os.path.join(opts.save_path, 'model_latest.net')
        torch.save(save_state, model_path)
        prototypes = model.save_prototypes(opts.save_path, 'prototypes_latest.png')
        x = torchvision.utils.make_grid(prototypes, nrow=10, pad_value=1.0)
        logger.writer.add_image('Prototypes (latest)', x, epoch)
        ae_samples = model.get_decoded_pairs_grid(valid_sample)
        logger.writer.add_image('AE_samples_latest', ae_samples, epoch)