def main() -> None:
    """Score the test tweets with the Google Cloud NL sentiment API and log predictions.

    Loads config/tweets, queries `analyze_sentiment` per tweet, maps each scored
    tweet to a -1/1 label (0 = no prediction), then uploads the raw responses and
    a predictions table to Comet — or writes a CSV locally when Comet is disabled.
    """
    args = get_args()
    config = get_bunch_config_from_json(args.config)

    comet_experiment = Experiment(
        api_key=config.comet_api_key,
        project_name=config.comet_project_name,
        workspace=config.comet_workspace,
        disabled=not config.use_comet_experiments,
    )
    comet_experiment.set_name(config.experiment_name)
    comet_experiment.log_parameters(config)

    test_tweets = load_test_tweets(config.test_data_path)
    client = LanguageServiceClient()

    result = []
    # 0 marks "no prediction"; filled with -1 / 1 below.
    predictions = np.zeros(len(test_tweets), dtype=np.int32)

    for i, tweet in enumerate(test_tweets):
        start_iter_timestamp = time.time()
        document = types.Document(
            type=enums.Document.Type.PLAIN_TEXT, content=tweet, language="en"
        )
        response = client.analyze_sentiment(document=document)
        response_dict = MessageToDict(response)
        result.append(response_dict)

        # FIX: use .get() so a response without "documentSentiment" does not
        # raise KeyError mid-run; such tweets simply keep prediction 0.
        prediction_present = bool(response_dict.get("documentSentiment"))
        if prediction_present:
            # score > 0 -> 1, score <= 0 -> -1
            predictions[i] = 2 * (response.document_sentiment.score > 0) - 1

        print("iteration", i, "took:", time.time() - start_iter_timestamp, "seconds")

    comet_experiment.log_asset_data(result, name="google_nlp_api_response.json")

    # 1-based ids paired with their -1/0/1 predictions.
    ids = np.arange(1, len(test_tweets) + 1).astype(np.int32)
    predictions_table = np.column_stack((ids, predictions))

    if comet_experiment.disabled:
        save_path = build_save_path(config)
        # FIX: exist_ok=True so a pre-existing output directory does not abort the run.
        os.makedirs(save_path, exist_ok=True)
        formatted_predictions_table = pd.DataFrame(
            predictions_table,
            columns=["Id", "Prediction"],
            dtype=np.int32,
        )
        formatted_predictions_table.to_csv(
            os.path.join(save_path, "google_nlp_api_predictions.csv"), index=False
        )
    else:
        comet_experiment.log_table(
            filename="google_nlp_api_predictions.csv",
            tabular_data=predictions_table,
            headers=["Id", "Prediction"],
        )

    # Fraction of tweets for which the API returned a usable sentiment.
    percentage_predicted = np.sum(predictions != 0) / predictions.shape[0]
    comet_experiment.log_metric(name="percentage predicted", value=percentage_predicted)
def init_experiment(experiment: Experiment, dataset: Dataset):
    """
    Initializes an experiment by logging the template and the validation set
    ground truths if they have not already been logged.
    """
    api_experiment = APIExperiment(previous_experiment=experiment.id)
    try:
        # Presence of the template asset is used as the "already initialized" marker.
        api_experiment.get_asset("datatap/template.json")
        return
    except NotFound:
        pass

    ground_truth = [
        annotation.to_json()
        for annotation in dataset.stream_split("validation")
    ]
    experiment.log_asset_data(
        ground_truth, name="datatap/validation/ground_truth.json"
    )
    experiment.log_asset_data(
        dataset.template.to_json(), name="datatap/template.json"
    )
class Trainer():
    """Training/evaluation driver for a MAC VQA model on CLEVR or GQA.

    Maintains two copies of the model: ``self.model`` (trained directly) and
    ``self.model_ema`` (exponential moving average of the weights, updated via
    ``weight_moving_average``). Progress is logged to TensorBoard and Comet.
    """

    def __init__(self, log_dir, cfg):
        # log_dir: root output directory; cfg: experiment config object.
        self.path = log_dir
        self.cfg = cfg

        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(self.path, 'Model')
            self.log_dir = os.path.join(self.path, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.log_dir)
            self.writer = SummaryWriter(log_dir=self.log_dir)
            self.logfile = os.path.join(self.path, "logfile.log")
            # Redirect stdout so all prints are also captured in logfile.log.
            sys.stdout = Logger(logfile=self.logfile)

        self.data_dir = cfg.DATASET.DATA_DIR
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.lr = cfg.TRAIN.LEARNING_RATE

        # Pin all work to the first configured GPU.
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        sample = cfg.SAMPLE
        self.dataset = []
        self.dataloader = []
        self.use_feats = cfg.model.use_feats
        eval_split = cfg.EVAL if cfg.EVAL else 'val'
        train_split = cfg.DATASET.train_split
        if cfg.DATASET.DATASET == 'clevr':
            clevr_collate_fn = collate_fn
            # COGENT is a suffix ('A'/'B') appended to the split names when set.
            cogent = cfg.DATASET.COGENT
            if cogent:
                print(f'Using CoGenT {cogent.upper()}')

            if cfg.TRAIN.FLAG:
                self.dataset = ClevrDataset(data_dir=self.data_dir, split=train_split + cogent,
                                            sample=sample, **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True, num_workers=cfg.WORKERS, drop_last=True,
                                             collate_fn=clevr_collate_fn)

            self.dataset_val = ClevrDataset(data_dir=self.data_dir, split=eval_split + cogent,
                                            sample=sample, **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val, batch_size=cfg.TEST_BATCH_SIZE,
                                             drop_last=False, shuffle=False, num_workers=cfg.WORKERS,
                                             collate_fn=clevr_collate_fn)
        elif cfg.DATASET.DATASET == 'gqa':
            # GQA needs a collate fn matching the feature type (spatial grid vs object list).
            if self.use_feats == 'spatial':
                gqa_collate_fn = collate_fn_gqa
            elif self.use_feats == 'objects':
                gqa_collate_fn = collate_fn_gqa_objs
            if cfg.TRAIN.FLAG:
                self.dataset = GQADataset(data_dir=self.data_dir, split=train_split,
                                          sample=sample, use_feats=self.use_feats,
                                          **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True, num_workers=cfg.WORKERS, drop_last=True,
                                             collate_fn=gqa_collate_fn)

            self.dataset_val = GQADataset(data_dir=self.data_dir, split=eval_split,
                                          sample=sample, use_feats=self.use_feats,
                                          **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val, batch_size=cfg.TEST_BATCH_SIZE,
                                             shuffle=False, num_workers=cfg.WORKERS, drop_last=False,
                                             collate_fn=gqa_collate_fn)

        # load model
        self.vocab = load_vocab(cfg)
        self.model, self.model_ema = mac.load_MAC(cfg, self.vocab)
        # alpha=0 copies the live weights into the EMA model so both start identical.
        self.weight_moving_average(alpha=0)
        if cfg.TRAIN.RADAM:
            self.optimizer = RAdam(self.model.parameters(), lr=self.lr)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.start_epoch = 0
        if cfg.resume_model:
            location = 'cuda' if cfg.CUDA else 'cpu'
            state = torch.load(cfg.resume_model, map_location=location)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optim'])
            self.start_epoch = state['iter'] + 1
            # EMA weights are stored in a separate checkpoint file.
            state = torch.load(cfg.resume_model_ema, map_location=location)
            self.model_ema.load_state_dict(state['model'])
        if cfg.start_epoch is not None:
            # Explicit config value overrides the epoch recovered from the checkpoint.
            self.start_epoch = cfg.start_epoch

        # Best-so-far trackers used for early stopping and snapshotting.
        self.previous_best_acc = 0.0
        self.previous_best_epoch = 0
        self.previous_best_loss = 100
        self.previous_best_loss_epoch = 0

        self.total_epoch_loss = 0
        self.prior_epoch_loss = 10

        self.print_info()
        self.loss_fn = torch.nn.CrossEntropyLoss().cuda()

        # NOTE(review): `disabled=cfg.logcomet is False` assumes logcomet is a
        # real bool; a falsy non-False value (0, None) would leave Comet enabled — verify.
        self.comet_exp = Experiment(
            project_name=cfg.COMET_PROJECT_NAME,
            api_key=os.getenv('COMET_API_KEY'),
            workspace=os.getenv('COMET_WORKSPACE'),
            disabled=cfg.logcomet is False,
        )
        if cfg.logcomet:
            exp_name = cfg_to_exp_name(cfg)
            print(exp_name)
            self.comet_exp.set_name(exp_name)
            self.comet_exp.log_parameters(flatten_json_iterative_solution(cfg))
            self.comet_exp.log_asset(self.logfile)
            self.comet_exp.log_asset_data(json.dumps(cfg, indent=4), file_name='cfg.json')
            self.comet_exp.set_model_graph(str(self.model))
            if cfg.cfg_file:
                self.comet_exp.log_asset(cfg.cfg_file)

        # Persist the resolved config next to the run outputs.
        with open(os.path.join(self.path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=4)

    def print_info(self):
        """Print the config, dataset sizes, and model architecture to stdout/logfile."""
        print('Using config:')
        pprint.pprint(self.cfg)
        print("\n")
        pprint.pprint("Size of train dataset: {}".format(len(self.dataset)))
        # print("\n")
        pprint.pprint("Size of val dataset: {}".format(len(self.dataset_val)))
        print("\n")
        print("Using MAC-Model:")
        pprint.pprint(self.model)
        print("\n")

    def weight_moving_average(self, alpha=0.999):
        """In-place EMA update: ema = alpha * ema + (1 - alpha) * live weights."""
        for param1, param2 in zip(self.model_ema.parameters(), self.model.parameters()):
            param1.data *= alpha
            param1.data += (1.0 - alpha) * param2.data

    def set_mode(self, mode="train"):
        """Switch both models between train and eval mode (any mode != 'train' means eval)."""
        if mode == "train":
            self.model.train()
            self.model_ema.train()
        else:
            self.model.eval()
            self.model_ema.eval()

    def reduce_lr(self):
        """Halve the LR when the epoch-over-epoch loss improvement plateaus.

        Three (improvement, loss-level, lr-floor) thresholds gate progressively
        finer reductions; also rolls the epoch-loss accumulators forward.
        """
        epoch_loss = self.total_epoch_loss  # / float(len(self.dataset) // self.batch_size)
        lossDiff = self.prior_epoch_loss - epoch_loss
        if ((lossDiff < 0.015 and self.prior_epoch_loss < 0.5 and self.lr > 0.00002) or
                (lossDiff < 0.008 and self.prior_epoch_loss < 0.15 and self.lr > 0.00001) or
                (lossDiff < 0.003 and self.prior_epoch_loss < 0.10 and self.lr > 0.000005)):
            self.lr *= 0.5
            print("Reduced learning rate to {}".format(self.lr))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
        self.prior_epoch_loss = epoch_loss
        self.total_epoch_loss = 0

    def save_models(self, iteration):
        """Checkpoint both the live model (with optimizer state) and the EMA model."""
        save_model(self.model, self.optimizer, iteration, self.model_dir, model_name="model")
        save_model(self.model_ema, None, iteration, self.model_dir, model_name="model_ema")

    def train_epoch(self, epoch):
        """Run one training epoch; returns a dict of average loss/accuracy."""
        cfg = self.cfg
        total_loss = 0.
        total_correct = 0
        total_samples = 0
        self.labeled_data = iter(self.dataloader)
        self.set_mode("train")

        dataset = tqdm(self.labeled_data, total=len(self.dataloader), ncols=20)

        for data in dataset:
            ######################################################
            # (1) Prepare training data
            ######################################################
            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)

            if cfg.CUDA:
                # 'objects' features come as a list of per-image tensors.
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()
            else:
                question = question
                image = image
                answer = answer.squeeze()

            ############################
            # (2) Train Model
            ############################
            self.optimizer.zero_grad()

            scores = self.model(image, question, question_len)
            loss = self.loss_fn(scores, answer)
            loss.backward()

            if self.cfg.TRAIN.CLIP_GRADS:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.TRAIN.CLIP)

            self.optimizer.step()
            # Keep the EMA model tracking the freshly updated weights.
            self.weight_moving_average()

            ############################
            # (3) Log Progress
            ############################
            correct = scores.detach().argmax(1) == answer
            total_correct += correct.sum().cpu().item()
            # Sample-weighted running averages over the epoch so far.
            total_loss += loss.item() * answer.size(0)
            total_samples += answer.size(0)

            avg_loss = total_loss / total_samples
            train_accuracy = total_correct / total_samples
            # accuracy = correct.sum().cpu().numpy() / answer.shape[0]
            # if avg_loss == 0:
            #     avg_loss = loss.item()
            #     train_accuracy = accuracy
            # else:
            #     avg_loss = 0.99 * avg_loss + 0.01 * loss.item()
            #     train_accuracy = 0.99 * train_accuracy + 0.01 * accuracy
            # self.total_epoch_loss += loss.item() * answer.size(0)

            dataset.set_description(
                'Epoch: {}; Avg Loss: {:.5f}; Avg Train Acc: {:.5f}'.format(
                    epoch + 1, avg_loss, train_accuracy))

        # Consumed by reduce_lr() at the end of the epoch.
        self.total_epoch_loss = avg_loss

        # NOTE(review): `dict` shadows the builtin; harmless here but rename if touched.
        dict = {
            "loss": avg_loss,
            "accuracy": train_accuracy,
            "avg_loss": avg_loss,  # For commet
            "avg_accuracy": train_accuracy,  # For commet
        }
        return dict

    def train(self):
        """Main training loop: train, validate, log, and early-stop over epochs."""
        cfg = self.cfg
        print("Start Training")
        for epoch in range(self.start_epoch, self.max_epochs):
            # comet_exp.train()/validate() namespace the logged metrics.
            with self.comet_exp.train():
                dict = self.train_epoch(epoch)
                self.reduce_lr()
                dict['epoch'] = epoch + 1
                dict['lr'] = self.lr
                self.comet_exp.log_metrics(
                    dict,
                    epoch=epoch + 1,
                )
            with self.comet_exp.validate():
                dict = self.log_results(epoch, dict)
                dict['epoch'] = epoch + 1
                dict['lr'] = self.lr
                self.comet_exp.log_metrics(
                    dict,
                    epoch=epoch + 1,
                )

            # NOTE(review): 'EALRY_STOPPING' is a (presumably intentional) config-key
            # typo — must match the key used in the config files; confirm before renaming.
            if cfg.TRAIN.EALRY_STOPPING:
                # Stop when no new best accuracy for PATIENCE epochs.
                if epoch - cfg.TRAIN.PATIENCE == self.previous_best_epoch:
                    # if epoch - cfg.TRAIN.PATIENCE == self.previous_best_loss_epoch:
                    print('Early stop')
                    break

        self.comet_exp.log_asset(self.logfile)
        self.save_models(self.max_epochs)
        self.writer.close()
        print("Finished Training")
        print(
            f"Highest validation accuracy: {self.previous_best_acc} at epoch {self.previous_best_epoch}"
        )

    def log_results(self, epoch, dict, max_eval_samples=None):
        """Validate, write TensorBoard scalars, update best-so-far, snapshot; returns metrics."""
        epoch += 1
        self.writer.add_scalar("avg_loss", dict["loss"], epoch)
        self.writer.add_scalar("train_accuracy", dict["accuracy"], epoch)

        metrics = self.calc_accuracy("validation", max_samples=max_eval_samples)
        self.writer.add_scalar("val_accuracy_ema", metrics['acc_ema'], epoch)
        self.writer.add_scalar("val_accuracy", metrics['acc'], epoch)
        self.writer.add_scalar("val_loss_ema", metrics['loss_ema'], epoch)
        self.writer.add_scalar("val_loss", metrics['loss'], epoch)

        print(
            "Epoch: {epoch}\tVal Acc: {acc},\tVal Acc EMA: {acc_ema},\tAvg Loss: {loss},\tAvg Loss EMA: {loss_ema},\tLR: {lr}"
            .format(epoch=epoch, lr=self.lr, **metrics))

        if metrics['acc'] > self.previous_best_acc:
            self.previous_best_acc = metrics['acc']
            self.previous_best_epoch = epoch
        if metrics['loss'] < self.previous_best_loss:
            self.previous_best_loss = metrics['loss']
            self.previous_best_loss_epoch = epoch

        if epoch % self.snapshot_interval == 0:
            self.save_models(epoch)

        return metrics

    def calc_accuracy(self, mode="train", max_samples=None):
        """Evaluate both models over a full loader pass.

        Returns dict(acc, acc_ema, loss, loss_ema), all sample-weighted averages.
        NOTE(review): `max_samples` is accepted but never used — evaluation always
        runs the whole loader.
        """
        # Always puts the models in eval mode (set_mode treats any non-'train' as eval).
        self.set_mode("validation")

        if mode == "train":
            loader = self.dataloader
        # elif (mode == "validation") or (mode == 'test'):
        #     loader = self.dataloader_val
        else:
            loader = self.dataloader_val

        total_correct = 0
        total_correct_ema = 0
        total_samples = 0
        total_loss = 0.
        total_loss_ema = 0.
        pbar = tqdm(loader, total=len(loader), desc=mode.upper(), ncols=20)
        for data in pbar:
            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)
            if self.cfg.CUDA:
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()

            with torch.no_grad():
                scores = self.model(image, question, question_len)
                scores_ema = self.model_ema(image, question, question_len)

                loss = self.loss_fn(scores, answer)
                loss_ema = self.loss_fn(scores_ema, answer)

            correct = scores.detach().argmax(1) == answer
            correct_ema = scores_ema.detach().argmax(1) == answer

            total_correct += correct.sum().cpu().item()
            total_correct_ema += correct_ema.sum().cpu().item()

            total_loss += loss.item() * answer.size(0)
            total_loss_ema += loss_ema.item() * answer.size(0)
            total_samples += answer.size(0)

            avg_acc = total_correct / total_samples
            avg_acc_ema = total_correct_ema / total_samples
            avg_loss = total_loss / total_samples
            avg_loss_ema = total_loss_ema / total_samples

            pbar.set_postfix({
                'Acc': f'{avg_acc:.5f}',
                'Acc Ema': f'{avg_acc_ema:.5f}',
                'Loss': f'{avg_loss:.5f}',
                'Loss Ema': f'{avg_loss_ema:.5f}',
            })

        return dict(acc=avg_acc, acc_ema=avg_acc_ema, loss=avg_loss, loss_ema=avg_loss_ema)
print('Model loaded.')
# Log the flattened config plus the trainable-parameter count to the experiment.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
log_dict = helpers.flatten_dict(config)
log_dict.update({'trainable_params': n_params})
exp.log_parameters(log_dict)

# Build the merged test set and a deterministic (unshuffled) loader over it.
test_dataset = data.CSVDatasetsMerger(helpers.get_datasets_paths(config, 'test'))
test_dataloader = DataLoader(test_dataset,
                             batch_size=config['evaluation']['eval_batch_size'],
                             shuffle=False,
                             drop_last=False,
                             num_workers=config['evaluation']['n_eval_workers'],
                             collate_fn=text_proc)

evaluator = Evaluation(test_dataloader, config)
print('Testing ...')
results, assets, image_fns = evaluator.eval_model(model, finished_training=True)
print('Finished testing. Uploading ...')

# Upload metrics, raw assets, and figures.
exp.log_metrics(results, step=0, epoch=0)
# FIX: plain loops instead of list comprehensions used only for their side
# effects (the built lists were discarded).
for asset in assets:
    exp.log_asset_data(asset, step=0)
for fn in image_fns:
    exp.log_image(fn, step=0)
print('Finished uploading.')
def log_validation_proposals(experiment: Experiment, proposals: Sequence[ImageAnnotation]):
    """Upload the validation-set proposals to the experiment as a JSON asset."""
    serialized = [annotation.to_json() for annotation in proposals]
    experiment.log_asset_data(
        serialized,
        name="datatap/validation/proposals.json",
    )
class Plotter:
    """
    Handles plotting and logging to comet.

    Args:
        exp_args (args.parse_args): arguments for the experiment
        agent_args (dict): arguments for the agent
        agent (Agent): the agent
    """
    def __init__(self, exp_args, agent_args, agent):
        self.exp_args = exp_args
        self.agent_args = agent_args
        self.agent = agent
        self.experiment = None
        if self.exp_args.plotting:
            self.experiment = Experiment(api_key=LOGGING_API_KEY,
                                         project_name=PROJECT_NAME,
                                         workspace=WORKSPACE)
            self.experiment.disable_mp()
            self.experiment.log_parameters(get_arg_dict(exp_args))
            self.experiment.log_parameters(flatten_arg_dict(agent_args))
            self.experiment.log_asset_data(json.dumps(get_arg_dict(exp_args)), name='exp_args')
            self.experiment.log_asset_data(json.dumps(agent_args), name='agent_args')
            if self.exp_args.checkpoint_exp_key is not None:
                self.load_checkpoint()
        self.result_dict = None
        # keep a hard-coded list of returns in case Comet fails
        self.returns = []

    def _plot_ts(self, key, observations, statistics, label, color):
        """Plot a time series of observations overlaid with distribution stats.

        One subplot per observation dimension (capped at 9). The distribution
        family is inferred from the number of statistics: 1 -> Bernoulli
        ('probs'), 2 -> Normal ('loc'/'scale') or Uniform ('low'/'high').
        """
        dim_obs = min(observations.shape[1], 9)
        k = 1
        for i in range(dim_obs):
            # e.g. dim_obs=3, k=2 -> subplot(312); only valid for dim_obs <= 9.
            plt.subplot(int(str(dim_obs) + '1' + str(k)))
            observations_i = observations[:-1, i].cpu().numpy()
            if key == 'action' and self.agent.postprocess_action:
                observations_i = np.tanh(observations_i)
            plt.plot(observations_i.squeeze(), 'o', label='observation', color='k', markersize=2)
            if len(statistics) == 1:
                # Bernoulli distribution
                probs = statistics['probs']
                probs = probs.cpu().numpy()
                plt.plot(probs, label=label, color=color)
            elif len(statistics) == 2:
                if 'loc' in statistics:
                    # Normal distribution: plot mean with a +/- std band.
                    mean = statistics['loc']
                    std = statistics['scale']
                    mean = mean[:, i].cpu().numpy()
                    std = std[:, i].cpu().numpy()
                    mean = mean.squeeze()
                    std = std.squeeze()
                    x, plus, minus = mean, mean + std, mean - std
                    # Squash through tanh whenever the distribution (or the
                    # agent's postprocessing) lives in tanh-space.
                    if key == 'action' and label == 'approx_post' and self.agent_args['approx_post_args']['dist_type'] in ['TanhNormal', 'TanhARNormal']:
                        # Tanh Normal distribution
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and label == 'direct_approx_post' and self.agent_args['approx_post_args']['dist_type'] in ['TanhNormal', 'TanhARNormal']:
                        # Tanh Normal distribution
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and label == 'prior' and self.agent_args['prior_args']['dist_type'] in ['TanhNormal', 'TanhARNormal']:
                        # Tanh Normal distribution
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and self.agent.postprocess_action:
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and label == 'prior' and self.agent_args['prior_args']['dist_type'] == 'NormalUniform':
                        # Normal + Uniform distribution: clamp the band to [-1, 1].
                        x, plus, minus = x, np.minimum(plus, 1.), np.maximum(minus, -1)
                elif 'low' in statistics:
                    # Uniform distribution: plot midpoint with a band.
                    low = statistics['low'][:, i].cpu().numpy()
                    high = statistics['high'][:, i].cpu().numpy()
                    x = low + (high - low) / 2
                    # NOTE(review): band endpoints `x + high` / `x + low` look
                    # suspicious (plain `high` / `low` would be the usual band);
                    # left unchanged — confirm against the intended plot.
                    plus, minus = x + high, x + low
                else:
                    raise NotImplementedError
                plt.plot(x, label=label, color=color)
                plt.fill_between(np.arange(len(x)), plus, minus, color=color, alpha=0.2, label=label)
            else:
                # FIX: was a bare `NotImplementedError` expression (a silent
                # no-op); now raised to match the sibling branch above.
                raise NotImplementedError
            k += 1

    def plot_states_and_rewards(self, states, rewards, step):
        """
        Plots the states and rewards for a collected episode.
        """
        # states
        plt.figure()
        dim_obs = states.shape[1]
        for i in range(dim_obs):
            plt.subplot(dim_obs, 1, i+1)
            states_i = states[:-1, i].cpu().numpy()
            plt.plot(states_i.squeeze(), 'o', label='state', color='k', markersize=2)
        self.experiment.log_figure(figure=plt, figure_name='states_ts_'+str(step))
        plt.close()

        # rewards
        plt.figure()
        rewards = rewards[:-1, 0].cpu().numpy()
        plt.plot(rewards.squeeze(), 'o', label='reward', color='k', markersize=2)
        self.experiment.log_figure(figure=plt, figure_name='rewards_ts_'+str(step))
        plt.close()

    def plot_episode(self, episode, step):
        """
        Plots a newly collected episode: cumulative reward, one time-series
        figure per distribution key, then states and rewards.
        """
        if self.exp_args.plotting:
            self.experiment.log_metric('cumulative_reward', episode['reward'].sum(), step)

            def merge_legends():
                # De-duplicate legend entries produced by repeated labels.
                handles, labels = plt.gca().get_legend_handles_labels()
                newLabels, newHandles = [], []
                for handle, label in zip(handles, labels):
                    if label not in newLabels:
                        newLabels.append(label)
                        newHandles.append(handle)
                plt.legend(newHandles, newLabels)

            for k in episode['distributions'].keys():
                for i, l in enumerate(episode['distributions'][k].keys()):
                    color = COLORS[i]
                    self._plot_ts(k, episode[k], episode['distributions'][k][l], l, color)
                plt.suptitle(k)
                merge_legends()
                self.experiment.log_figure(figure=plt, figure_name=k + '_ts_'+str(step))
                plt.close()

            self.plot_states_and_rewards(episode['state'], episode['reward'], step)

    def log_eval(self, episode, eval_states, step):
        """
        Plots an evaluation episode performance. Logs the episode.

        Args:
            episode (dict): dictionary containing agent's collected episode
            eval_states (dict): dictionary of MuJoCo simulator states
            step (int): the current step number in training
        """
        # plot and log eval returns
        eval_return = episode['reward'].sum()
        print(' Eval. Return at Step ' + str(step) + ': ' + str(eval_return.item()))
        # Always record locally, even when Comet plotting is off.
        self.returns.append(eval_return.item())
        if self.exp_args.plotting:
            self.experiment.log_metric('eval_cumulative_reward', eval_return, step)
            json_str = json.dumps(self.returns)
            self.experiment.log_asset_data(json_str, name='eval_returns', overwrite=True)

            # log the episode itself
            for ep_item_str in ['state', 'action', 'reward']:
                ep_item = episode[ep_item_str].tolist()
                json_str = json.dumps(ep_item)
                item_name = 'episode_step_' + str(step) + '_' + ep_item_str
                self.experiment.log_asset_data(json_str, name=item_name)

            # log the MuJoCo simulator states
            for sim_item_str in ['qpos', 'qvel']:
                if len(eval_states[sim_item_str]) > 0:
                    sim_item = eval_states[sim_item_str].tolist()
                    json_str = json.dumps(sim_item)
                    item_name = 'episode_step_' + str(step) + '_' + sim_item_str
                    self.experiment.log_asset_data(json_str, name=item_name)

    def plot_agent_kl(self, agent_kl, step):
        """Log the agent KL metric for this step (no-op when plotting is off)."""
        if self.exp_args.plotting:
            self.experiment.log_metric('agent_kl', agent_kl, step)

    def log_results(self, results):
        """
        Log the results dictionary: accumulate each flattened key's values
        until plot_results() averages and flushes them.
        """
        if self.result_dict is None:
            self.result_dict = {}
        for k, v in flatten_arg_dict(results).items():
            if k not in self.result_dict:
                self.result_dict[k] = [v]
            else:
                self.result_dict[k].append(v)

    def plot_results(self, timestep):
        """
        Plot/log the results to Comet: mean of each accumulated metric, then reset.
        """
        if self.exp_args.plotting:
            for k, v in self.result_dict.items():
                avg_value = np.mean(v)
                self.experiment.log_metric(k, avg_value, timestep)
        self.result_dict = None

    def plot_model_eval(self, episode, predictions, log_likelihoods, step):
        """
        Plot/log the results from model evaluation.

        Args:
            episode (dict): a collected episode
            predictions (dict): predictions from each state, containing [n_steps, horizon, n_dims]
            log_likelihoods (dict): log-likelihood evaluations of predictions, containing [n_steps, horizon, 1]
        """
        if self.exp_args.plotting:
            for variable, lls in log_likelihoods.items():
                # average the log-likelihood estimates and plot the result at the horizon length
                mean_ll = lls[:, -1].mean().item()
                self.experiment.log_metric(variable + '_pred_log_likelihood', mean_ll, step)

                # plot log-likelihoods as a function of rollout step
                plt.figure()
                mean = lls.mean(dim=0).view(-1)
                std = lls.std(dim=0).view(-1)
                plt.plot(mean.numpy())
                lower = mean - std
                upper = mean + std
                plt.fill_between(np.arange(lls.shape[1]), lower.numpy(), upper.numpy(), alpha=0.2)
                plt.xlabel('Rollout Step')
                plt.ylabel('Prediction Log-Likelihood')
                plt.xticks(np.arange(lls.shape[1]))
                self.experiment.log_figure(figure=plt, figure_name=variable + '_pred_ll_' + str(step))
                plt.close()

            # plot predictions vs. actual values for an arbitrary time step
            time_step = np.random.randint(predictions['state']['loc'].shape[0])
            for variable, preds in predictions.items():
                pred_loc = preds['loc'][time_step]
                pred_scale = preds['scale'][time_step]
                # align the ground truth with the predicted horizon window
                x = episode[variable][time_step+1:time_step+1+pred_loc.shape[0]]
                plt.figure()
                horizon, n_dims = pred_loc.shape
                for plot_num in range(n_dims):
                    plt.subplot(n_dims, 1, plot_num + 1)
                    plt.plot(pred_loc[:, plot_num].numpy())
                    lower = pred_loc[:, plot_num] - pred_scale[:, plot_num]
                    upper = pred_loc[:, plot_num] + pred_scale[:, plot_num]
                    plt.fill_between(np.arange(horizon), lower.numpy(), upper.numpy(), alpha=0.2)
                    plt.plot(x[:, plot_num].numpy(), '.')
                plt.xlabel('Rollout Step')
                plt.xticks(np.arange(horizon))
                self.experiment.log_figure(figure=plt, figure_name=variable + '_pred_' + str(step))
                plt.close()

    def save_checkpoint(self, step):
        """
        Checkpoint the model by getting the state dictionary for each component.
        The checkpoint is uploaded to Comet and the local file removed afterwards.
        """
        if self.exp_args.plotting:
            print('Checkpointing the agent...')
            state_dict = self.agent.state_dict()
            # move everything to CPU so the checkpoint loads on any device
            cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
            ckpt_path = os.path.join('./ckpt_step_'+ str(step) + '.ckpt')
            torch.save(cpu_state_dict, ckpt_path)
            self.experiment.log_asset(ckpt_path)
            os.remove(ckpt_path)
            print('Done.')

    def load_checkpoint(self, timestep=None):
        """
        Loads a checkpoint from Comet.

        Args:
            timestep (int, optional): the checkpoint timestep, default is latest
        """
        load_checkpoint(self.agent, self.exp_args.checkpoint_exp_key, timestep)
class Trainer:
    """GAN training driver (TensorFlow 2): generator vs. discriminator with BCE-on-logits.

    Logs to a tf.summary file writer and Comet when a log_dir is provided;
    checkpoints generator/discriminator/optimizers plus epoch/step counters.
    """

    def __init__(self, cfg, log_dir):
        # cfg: experiment config; log_dir: output dir (falsy disables all logging).
        self.log_dir = log_dir
        self.cfg = cfg
        if log_dir:
            self.log = True
            self.model_dir = osp.join(log_dir, "checkpoints")
            mkdir_p(self.model_dir)
            self.logfile = osp.join(log_dir, "logfile.log")
            # Mirror stdout into the logfile.
            sys.stdout = Logger(logfile=self.logfile)
            self.summary_writer = tf.summary.create_file_writer(log_dir)
        else:
            self.log = False

        self.generator = Generator(**cfg.model.generator)
        self.discriminator = Discriminator(**cfg.model.discriminator)
        self.g_optimizer = optimizers.Adam(**cfg.train.generator.optimizer)
        self.d_optimizer = optimizers.Adam(**cfg.train.discriminator.optimizer)
        # Both losses are computed from raw logits.
        self.bce = losses.BinaryCrossentropy(from_logits=True)
        # TODO resume model

        # NOTE(review): `cfg.logcomet is False` assumes logcomet is a real bool;
        # a falsy non-False value would not disable Comet — verify.
        self.comet = Experiment(
            api_key=os.getenv("COMET_API_KEY"),
            workspace=os.getenv("COMET_WORKSPACE"),
            project_name=cfg.comet_project_name,
            disabled=cfg.logcomet is False or not self.log,
        )
        self.comet.set_name(f"{cfg.experiment_name}/{cfg.run_name}")
        self.comet.log_parameters(flatten_json_iterative_solution(self.cfg))
        self.comet.log_asset_data(json.dumps(self.cfg, indent=4), file_name="cfg.json")
        if cfg.cfg_file:
            self.comet.log_asset(cfg.cfg_file)

        # Epoch/step counters live in the checkpoint so resumed runs continue counting.
        self.start_epoch = tf.Variable(0)
        self.curr_step = tf.Variable(0)
        self.ckpt = tf.train.Checkpoint(
            generator=self.generator,
            discriminator=self.discriminator,
            g_optimizer=self.g_optimizer,
            d_optimizer=self.d_optimizer,
            start_epoch=self.start_epoch,
            curr_step=self.curr_step,
        )
        if cfg.train.resume:
            ckpt_resumer = tf.train.CheckpointManager(
                self.ckpt, cfg.train.resume, max_to_keep=3,
            )
            # if a checkpoint exists, restore the latest checkpoint.
            if ckpt_resumer.latest_checkpoint:
                self.ckpt.restore(ckpt_resumer.latest_checkpoint)
                print("Latest checkpoint restored!!", ckpt_resumer.latest_checkpoint)
                print(
                    f"Last epoch trained:{self.start_epoch.numpy()}, Current step: {self.curr_step.numpy()}"
                )

        if self.log:
            # Persist the resolved config and set up the save-side checkpoint manager.
            with open(osp.join(self.log_dir, "cfg.json"), "w") as f:
                json.dump(cfg, f, indent=4)
            self.ckpt_manager = tf.train.CheckpointManager(
                self.ckpt, self.model_dir, max_to_keep=3
            )

        self.prepare_dataset(self.cfg.train.data_dir)
        self.print_info()

        if self.cfg.train.generator.fixed_z:
            # Fixed latent/view for overfitting a single sample (debugging).
            self.z_bg = sample_z(1, self.generator.z_dim_bg, num_objects=1)
            self.z_fg = sample_z(1, self.generator.z_dim_fg, num_objects=1)
            self.bg_view = sample_view(1, num_objects=1)
            self.fg_view = sample_view(1, num_objects=1)
        else:
            self.z_bg = self.z_fg = self.bg_view = self.fg_view = None

    def prepare_dataset(self, data_dir):
        """Build the PNG-file training dataset pipeline and step count."""
        self.data_dir = data_dir
        self.num_tr = len(glob.glob(osp.join(self.data_dir, "*.png")))
        self.list_ds_train = tf.data.Dataset.list_files(
            os.path.join(self.data_dir, "*.png")
        )
        self.labeled_ds = self.list_ds_train.map(
            lambda x: process_path(
                x, self.cfg.train.image_height, self.cfg.train.image_width
            ),
            num_parallel_calls=AUTOTUNE,
        )
        self.steps_per_epoch = int(math.ceil(self.num_tr / self.cfg.train.batch_size))

    def print_info(self):
        """Print the config and training-set size to stdout/logfile."""
        print("Using config:")
        pprint.pprint(self.cfg)
        print("\n")
        pprint.pprint("Size of train dataset: {}".format(self.num_tr))
        print("\n")

    # lossess
    def discriminator_loss(self, real, generated):
        """Mean BCE of (real->1, generated->0), halved to balance vs. the generator step."""
        real_loss = self.bce(tf.ones_like(real), real)
        generated_loss = self.bce(tf.zeros_like(generated), generated)
        total_disc_loss = real_loss + generated_loss
        return total_disc_loss * 0.5

    def generator_loss(self, generated):
        """Non-saturating generator loss: BCE of generated logits against 1."""
        return self.bce(tf.ones_like(generated), generated)

    # def generate_random_noise(self, batch_size, num_objects=(3, 10)):
    #     z_bg = tf.random.uniform(
    #         (batch_size, self.generator.z_dim_bg), minval=-1, maxval=1
    #     )
    #     num_objs = tf.random.uniform(
    #         (batch_size,),
    #         minval=num_objects[0],
    #         maxval=num_objects[1] + 1,
    #         dtype=tf.int32,
    #     )
    #     tensors = []
    #     max_len = max(num_objs)
    #     for no in num_objs:
    #         _t = tf.random.uniform((no, self.generator.z_dim_fg), minval=-1, maxval=1)
    #         _z = tf.zeros((max_len - no, self.generator.z_dim_fg), dtype=tf.float32)
    #         _t = tf.concat((_t, _z), axis=0)
    #         tensors.append(_t)
    #     z_fg = tf.stack(tensors, axis=0)
    #     return z_bg, z_fg

    def batch_logits(self, image_batch, z_bg, z_fg, bg_view, fg_view):
        """Run G and D once; returns (fake logits, real logits, generated images).

        Real images are rescaled from [0, 1] to [-1, 1] to match G's output range.
        """
        generated = self.generator(z_bg, z_fg, bg_view, fg_view)
        d_fake_logits = self.discriminator(generated, training=True)

        image_batch = (image_batch * 2) - 1
        # Instance noise on real images (always during the first 2000 steps).
        if self.cfg.train.discriminator.random_noise or self.curr_step <= 2000:
            image_batch = image_batch + tf.random.normal(image_batch.shape, stddev=0.01)
        d_real_logits = self.discriminator(image_batch, training=True,)
        return d_fake_logits, d_real_logits, generated

    # @tf.function
    def train_epoch(self, epoch):
        """Run one adversarial epoch: per batch, one D step and one G step from a shared tape."""
        train_iter = prepare_for_training(
            self.labeled_ds, self.cfg.train.batch_size, cache=False,
        )
        pbar = tqdm(
            enumerate(train_iter),
            total=self.steps_per_epoch,
            ncols=20,
            desc=f"Epoch {epoch}",
            mininterval=10,
            miniters=50,
        )
        # Running sums since the last log flush; counter divides them into means.
        total_d_loss = 0.0
        total_g_loss = 0.0
        counter = 1
        real_are_real_samples_counter = 0
        real_samples_counter = 0
        fake_are_fake_samples_counter = 0
        fake_samples_counter = 0
        z_bg = z_fg = None
        for it, image_batch in pbar:
            bsz = image_batch.shape[0]
            # generated random noise
            if self.z_bg is not None:
                # For overfitting one sample and debugging
                z_bg = tf.repeat(self.z_bg, bsz, axis=0)
                z_fg = tf.repeat(self.z_fg, bsz, axis=0)
                bg_view = self.bg_view
                fg_view = self.fg_view
            else:
                z_bg = sample_z(bsz, self.generator.z_dim_bg, num_objects=1)
                # Object count ramps up with epoch: 3..min(10, 3 + epoch//2).
                z_fg = sample_z(
                    bsz,
                    self.generator.z_dim_fg,
                    num_objects=(3, min(10, 3 + 1 * (epoch // 2))),
                )
                bg_view = sample_view(
                    batch_size=bsz,
                    num_objects=1,
                    azimuth_range=(-20, 20),
                    elevation_range=(-10, 10),
                    scale_range=(0.9, 1.1),
                )
                fg_view = sample_view(batch_size=bsz, num_objects=z_fg.shape[1])

            # Persistent tape: both losses share one forward pass, two gradient calls.
            with tf.GradientTape(persistent=True) as tape:
                # fake img
                d_fake_logits, d_real_logits, generated = self.batch_logits(
                    image_batch, z_bg, z_fg, bg_view, fg_view
                )
                d_loss = self.discriminator_loss(d_real_logits, d_fake_logits)
                g_loss = self.generator_loss(d_fake_logits)

            total_d_loss += d_loss.numpy()
            # total_g_loss += g_loss.numpy() / self.cfg.train.generator.update_freq
            total_g_loss += g_loss.numpy()

            d_variables = self.discriminator.trainable_variables
            d_gradients = tape.gradient(d_loss, d_variables)
            self.d_optimizer.apply_gradients(zip(d_gradients, d_variables))

            g_variables = self.generator.trainable_variables
            g_gradients = tape.gradient(g_loss, g_variables)
            self.g_optimizer.apply_gradients(zip(g_gradients, g_variables))
            # Free the persistent tape's held resources.
            del tape

            real_samples_counter += d_real_logits.shape[0]
            fake_samples_counter += d_fake_logits.shape[0]

            # Logit sign gives the discriminator's hard decision.
            real_are_real = (d_real_logits >= 0).numpy().sum()
            real_are_real_samples_counter += real_are_real

            fake_are_fake = (d_fake_logits < 0).numpy().sum()
            fake_are_fake_samples_counter += fake_are_fake

            # according to paper generator makes 2 steps per each step of the disc
            # for _ in range(self.cfg.train.generator.update_freq - 1):
            #     with tf.GradientTape(persistent=True) as tape:
            #         # fake img
            #         d_fake_logits, _, generated = self.batch_logits(
            #             image_batch, z_bg, z_fg
            #         )
            #         g_loss = self.generator_loss(d_fake_logits)
            #     g_variables = self.generator.trainable_variables
            #     g_gradients = tape.gradient(g_loss, g_variables)
            #     self.g_optimizer.apply_gradients(zip(g_gradients, g_variables))
            #     total_g_loss += g_loss.numpy() / self.cfg.train.generator.update_freq

            pbar.set_postfix(
                g_loss=f"{g_loss.numpy():.2f} ({total_g_loss / (counter):.2f})",
                d_loss=f"{d_loss.numpy():.4f} ({total_d_loss / (counter):.4f})",
                rrr=f"{real_are_real / d_real_logits.shape[0]:.1f} ({real_are_real_samples_counter / real_samples_counter:.1f})",
                frf=f"{fake_are_fake / d_fake_logits.shape[0]:.1f} ({fake_are_fake_samples_counter / fake_samples_counter:.1f})",
                refresh=False,
            )
            if it % (self.cfg.train.it_log_interval) == 0:
                self.log_training(
                    d_loss=total_d_loss / counter,
                    g_loss=total_g_loss / counter,
                    real_are_real=real_are_real_samples_counter / real_samples_counter,
                    fake_are_fake=fake_are_fake_samples_counter / fake_samples_counter,
                    fake_images=(generated + 1) / 2,
                    real_images=image_batch,
                    d_fake_logits=d_fake_logits,
                    d_real_logits=d_real_logits,
                    epoch=epoch,
                    it=it,
                )
                # Reset accumulators for the next logging window (counter
                # becomes 1 after the increment below).
                real_are_real_samples_counter = 0
                fake_are_fake_samples_counter = 0
                real_samples_counter = 0
                fake_samples_counter = 0
                total_d_loss = 0.0
                total_g_loss = 0.0
                counter = 0
            counter += 1
        gc.collect()
        del train_iter

    def log_training(
        self,
        d_loss,
        g_loss,
        fake_images,
        real_images,
        d_fake_logits,
        d_real_logits,
        epoch,
        it,
        real_are_real,
        fake_are_fake,
    ):
        """Write scalar/image summaries and Comet metrics/figures for one log window.

        No-op unless logging is enabled. Images are grouped by the
        discriminator's real/fake verdict.
        """
        if self.log:
            curr_step = (self.curr_step + it).numpy()
            real_are_real_images, real_are_fake_images = split_images_on_disc(
                real_images, d_real_logits
            )
            fake_are_real_images, fake_are_fake_images = split_images_on_disc(
                fake_images, d_fake_logits
            )
            with self.summary_writer.as_default():
                tf.summary.scalar(
                    "losses/d_loss",
                    d_loss,
                    step=curr_step,
                    description="Average of predicting real images as real and fake as fake",
                )
                tf.summary.scalar(
                    "losses/g_loss",
                    g_loss,
                    step=curr_step,
                    description="Predicting fake images as real",
                )
                tf.summary.scalar(
                    "accuracy/real",
                    real_are_real,
                    step=curr_step,
                    description="Real images classified as real",
                )
                tf.summary.scalar(
                    "accuracy/fake",
                    fake_are_fake,
                    step=curr_step,
                    description="Fake images classified as fake",
                )
                tf.summary.image(
                    f"{epoch}-{curr_step}-fake/are_fake",
                    fake_are_fake_images,
                    max_outputs=25,
                    step=curr_step,
                    description="Fake images that the discriminator says are fake",
                )
                tf.summary.image(
                    f"{epoch}-{curr_step}-fake/are_real",
                    fake_are_real_images,
                    max_outputs=25,
                    step=curr_step,
                    description="Fake images that the discriminator says are real",
                )
                tf.summary.image(
                    f"{epoch}-{curr_step}-real/are_fake",
                    real_are_fake_images,
                    max_outputs=25,
                    step=curr_step,
                    description="Real images that the discriminator says are fake",
                )
                tf.summary.image(
                    f"{epoch}-{curr_step}-real/are_real",
                    real_are_real_images,
                    max_outputs=25,
                    step=curr_step,
                    description="Real images that the discriminator says are real",
                )

            self.comet.log_metrics(
                {"d_loss": d_loss, "g_loss": g_loss}, step=curr_step, epoch=epoch
            )
            fig = show_batch(fake_images, labels=disc_preds_to_label(d_fake_logits))
            self.comet.log_figure(
                figure=fig, figure_name="" f"fake_{epoch}_{it}.jpg", step=curr_step,
            )
            plt.close(fig)
            fig = show_batch(real_images, labels=disc_preds_to_label(d_real_logits))
            self.comet.log_figure(
                figure=fig, figure_name="" f"real_{epoch}_{it}.jpg", step=curr_step,
            )
            plt.close(fig)

    def train(self):
        """Loop over epochs, advancing the checkpointed counters and saving snapshots."""
        print("Start training")
        for epoch in range(self.start_epoch.numpy(), self.cfg.train.epochs):
            with self.comet.train():
                self.train_epoch(epoch)
            self.curr_step.assign_add(self.steps_per_epoch)
            self.start_epoch.assign_add(1)
            if self.log and (((epoch + 1) % self.cfg.train.snapshot_interval) == 0):
                self.ckpt_manager.save(epoch + 1)

    def save_model(self, epoch):
        # Intentionally empty: persistence is handled by ckpt_manager in train().
        pass