示例#1
0
def main() -> None:
    """Run Google NLP API sentiment analysis over the test tweets.

    Logs -1/1 predictions and the raw API responses to Comet, or writes the
    predictions CSV to disk when the Comet experiment is disabled. Also logs
    the fraction of tweets for which the API returned a sentiment.
    """
    args = get_args()
    config = get_bunch_config_from_json(args.config)

    comet_experiment = Experiment(
        api_key=config.comet_api_key,
        project_name=config.comet_project_name,
        workspace=config.comet_workspace,
        disabled=not config.use_comet_experiments,
    )
    comet_experiment.set_name(config.experiment_name)
    comet_experiment.log_parameters(config)

    test_tweets = load_test_tweets(config.test_data_path)

    client = LanguageServiceClient()
    responses = []
    # 0 marks "no prediction from the API"; entries become -1/1 below.
    predictions = np.zeros(len(test_tweets), dtype=np.int32)

    for i, tweet in enumerate(test_tweets):
        start_iter_timestamp = time.time()
        document = types.Document(
            type=enums.Document.Type.PLAIN_TEXT, content=tweet, language="en"
        )

        response = client.analyze_sentiment(document=document)
        response_dict = MessageToDict(response)
        responses.append(response_dict)

        prediction_present = bool(response_dict["documentSentiment"])
        if prediction_present:
            # Map sentiment score to -1/1; note a score of exactly 0 maps to -1.
            predictions[i] = 2 * (response.document_sentiment.score > 0) - 1

        print("iteration", i, "took:", time.time() - start_iter_timestamp, "seconds")

    comet_experiment.log_asset_data(responses, name="google_nlp_api_response.json")

    # Submission format: 1-based ids paired with their predictions.
    ids = np.arange(1, len(test_tweets) + 1).astype(np.int32)
    predictions_table = np.column_stack((ids, predictions))

    if comet_experiment.disabled:
        save_path = build_save_path(config)
        # exist_ok avoids crashing after a full (billed) API pass just because
        # the output directory is left over from a previous run.
        os.makedirs(save_path, exist_ok=True)

        formatted_predictions_table = pd.DataFrame(
            predictions_table, columns=["Id", "Prediction"], dtype=np.int32,
        )
        formatted_predictions_table.to_csv(
            os.path.join(save_path, "google_nlp_api_predictions.csv"), index=False
        )
    else:
        comet_experiment.log_table(
            filename="google_nlp_api_predictions.csv",
            tabular_data=predictions_table,
            headers=["Id", "Prediction"],
        )

    # Fraction of tweets that received any sentiment prediction (non-zero entry).
    percentage_predicted = np.sum(predictions != 0) / predictions.shape[0]
    comet_experiment.log_metric(name="percentage predicted", value=percentage_predicted)
示例#2
0
def init_experiment(experiment: Experiment, dataset: Dataset):
    """
    Initialize an experiment by logging the dataset template and the
    validation-set ground truths, unless the template asset was already
    uploaded for this experiment.
    """
    api_experiment = APIExperiment(previous_experiment=experiment.id)

    try:
        # Presence of the template asset means this experiment has already
        # been initialized; nothing more to do.
        api_experiment.get_asset("datatap/template.json")
        return
    except NotFound:
        pass

    ground_truths = [
        annotation.to_json()
        for annotation in dataset.stream_split("validation")
    ]
    experiment.log_asset_data(
        ground_truths, name="datatap/validation/ground_truth.json"
    )
    experiment.log_asset_data(
        dataset.template.to_json(), name="datatap/template.json"
    )
示例#3
0
class Trainer():
    """Training harness for a MAC-style VQA model on CLEVR or GQA.

    Builds datasets/dataloaders, the model and its exponential-moving-average
    (EMA) copy, the optimizer, and logging (TensorBoard + Comet), then drives
    the train/validate loop with LR reduction and optional early stopping.
    """

    def __init__(self, log_dir, cfg):

        self.path = log_dir
        self.cfg = cfg

        if cfg.TRAIN.FLAG:
            # Training run: create checkpoint/log dirs, open a TensorBoard
            # writer, and tee stdout into a logfile.
            self.model_dir = os.path.join(self.path, 'Model')
            self.log_dir = os.path.join(self.path, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.log_dir)
            self.writer = SummaryWriter(log_dir=self.log_dir)
            self.logfile = os.path.join(self.path, "logfile.log")
            sys.stdout = Logger(logfile=self.logfile)

        self.data_dir = cfg.DATASET.DATA_DIR
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # GPU_ID is a comma-separated string, e.g. "0,1".
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.lr = cfg.TRAIN.LEARNING_RATE

        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # Build train/val datasets and loaders for the configured dataset.
        sample = cfg.SAMPLE
        self.dataset = []
        self.dataloader = []
        self.use_feats = cfg.model.use_feats
        eval_split = cfg.EVAL if cfg.EVAL else 'val'
        train_split = cfg.DATASET.train_split
        if cfg.DATASET.DATASET == 'clevr':
            clevr_collate_fn = collate_fn
            # NOTE(review): COGENT is appended to the split names below, so it
            # is presumably '' or a suffix like 'a'/'b' — confirm config schema.
            cogent = cfg.DATASET.COGENT
            if cogent:
                print(f'Using CoGenT {cogent.upper()}')

            if cfg.TRAIN.FLAG:
                self.dataset = ClevrDataset(data_dir=self.data_dir,
                                            split=train_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=clevr_collate_fn)

            self.dataset_val = ClevrDataset(data_dir=self.data_dir,
                                            split=eval_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             drop_last=False,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             collate_fn=clevr_collate_fn)

        elif cfg.DATASET.DATASET == 'gqa':
            # NOTE(review): if use_feats is neither 'spatial' nor 'objects',
            # gqa_collate_fn is never bound and the DataLoader calls below
            # raise NameError.
            if self.use_feats == 'spatial':
                gqa_collate_fn = collate_fn_gqa
            elif self.use_feats == 'objects':
                gqa_collate_fn = collate_fn_gqa_objs
            if cfg.TRAIN.FLAG:
                self.dataset = GQADataset(data_dir=self.data_dir,
                                          split=train_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=gqa_collate_fn)

            self.dataset_val = GQADataset(data_dir=self.data_dir,
                                          split=eval_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             drop_last=False,
                                             collate_fn=gqa_collate_fn)

        # load model
        self.vocab = load_vocab(cfg)
        self.model, self.model_ema = mac.load_MAC(cfg, self.vocab)

        # alpha=0 copies the online model's weights into the EMA model.
        self.weight_moving_average(alpha=0)
        if cfg.TRAIN.RADAM:
            self.optimizer = RAdam(self.model.parameters(), lr=self.lr)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.start_epoch = 0
        if cfg.resume_model:
            # Restore model/optimizer state and resume from the saved epoch;
            # the EMA model is restored from its own checkpoint file.
            location = 'cuda' if cfg.CUDA else 'cpu'
            state = torch.load(cfg.resume_model, map_location=location)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optim'])
            self.start_epoch = state['iter'] + 1
            state = torch.load(cfg.resume_model_ema, map_location=location)
            self.model_ema.load_state_dict(state['model'])

        if cfg.start_epoch is not None:
            self.start_epoch = cfg.start_epoch

        # Best-so-far tracking used for early stopping and reporting.
        self.previous_best_acc = 0.0
        self.previous_best_epoch = 0
        self.previous_best_loss = 100
        self.previous_best_loss_epoch = 0

        self.total_epoch_loss = 0
        self.prior_epoch_loss = 10

        self.print_info()
        self.loss_fn = torch.nn.CrossEntropyLoss().cuda()

        self.comet_exp = Experiment(
            project_name=cfg.COMET_PROJECT_NAME,
            api_key=os.getenv('COMET_API_KEY'),
            workspace=os.getenv('COMET_WORKSPACE'),
            # NOTE(review): `is False` matches only the literal False, so a
            # missing/None logcomet still creates an enabled experiment.
            disabled=cfg.logcomet is False,
        )
        if cfg.logcomet:
            exp_name = cfg_to_exp_name(cfg)
            print(exp_name)
            self.comet_exp.set_name(exp_name)
            self.comet_exp.log_parameters(flatten_json_iterative_solution(cfg))
            self.comet_exp.log_asset(self.logfile)
            self.comet_exp.log_asset_data(json.dumps(cfg, indent=4),
                                          file_name='cfg.json')
            self.comet_exp.set_model_graph(str(self.model))
            if cfg.cfg_file:
                self.comet_exp.log_asset(cfg.cfg_file)

        # Persist the resolved config alongside the logs for reproducibility.
        with open(os.path.join(self.path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=4)

    def print_info(self):
        """Print the config, dataset sizes, and model architecture."""
        print('Using config:')
        pprint.pprint(self.cfg)
        print("\n")

        pprint.pprint("Size of train dataset: {}".format(len(self.dataset)))
        # print("\n")
        pprint.pprint("Size of val dataset: {}".format(len(self.dataset_val)))
        print("\n")

        print("Using MAC-Model:")
        pprint.pprint(self.model)
        print("\n")

    def weight_moving_average(self, alpha=0.999):
        """In-place EMA update: ema = alpha * ema + (1 - alpha) * model."""
        for param1, param2 in zip(self.model_ema.parameters(),
                                  self.model.parameters()):
            param1.data *= alpha
            param1.data += (1.0 - alpha) * param2.data

    def set_mode(self, mode="train"):
        """Put both models in train mode for "train", eval mode otherwise."""
        if mode == "train":
            self.model.train()
            self.model_ema.train()
        else:
            self.model.eval()
            self.model_ema.eval()

    def reduce_lr(self):
        """Halve the LR when the epoch-loss improvement stalls (thresholds
        tighten as the loss gets smaller), down to a per-tier LR floor."""
        epoch_loss = self.total_epoch_loss  # / float(len(self.dataset) // self.batch_size)
        lossDiff = self.prior_epoch_loss - epoch_loss
        if ((lossDiff < 0.015 and self.prior_epoch_loss < 0.5 and self.lr > 0.00002) or \
            (lossDiff < 0.008 and self.prior_epoch_loss < 0.15 and self.lr > 0.00001) or \
            (lossDiff < 0.003 and self.prior_epoch_loss < 0.10 and self.lr > 0.000005)):
            self.lr *= 0.5
            print("Reduced learning rate to {}".format(self.lr))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
        self.prior_epoch_loss = epoch_loss
        self.total_epoch_loss = 0

    def save_models(self, iteration):
        """Checkpoint the online model (with optimizer) and the EMA model."""
        save_model(self.model,
                   self.optimizer,
                   iteration,
                   self.model_dir,
                   model_name="model")
        save_model(self.model_ema,
                   None,
                   iteration,
                   self.model_dir,
                   model_name="model_ema")

    def train_epoch(self, epoch):
        """Run one training epoch and return a dict of averaged metrics
        ('loss'/'accuracy', duplicated as 'avg_*' for Comet)."""
        cfg = self.cfg
        total_loss = 0.
        total_correct = 0
        total_samples = 0

        self.labeled_data = iter(self.dataloader)
        self.set_mode("train")

        dataset = tqdm(self.labeled_data, total=len(self.dataloader), ncols=20)

        for data in dataset:
            ######################################################
            # (1) Prepare training data
            ######################################################
            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)

            if cfg.CUDA:
                # 'objects' features arrive as a list of tensors; move each.
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()
            else:
                question = question
                image = image
                answer = answer.squeeze()

            ############################
            # (2) Train Model
            ############################
            self.optimizer.zero_grad()

            scores = self.model(image, question, question_len)
            loss = self.loss_fn(scores, answer)
            loss.backward()

            if self.cfg.TRAIN.CLIP_GRADS:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.cfg.TRAIN.CLIP)

            self.optimizer.step()
            # Update the EMA copy after every optimizer step.
            self.weight_moving_average()

            ############################
            # (3) Log Progress
            ############################
            correct = scores.detach().argmax(1) == answer
            total_correct += correct.sum().cpu().item()
            total_loss += loss.item() * answer.size(0)
            total_samples += answer.size(0)

            avg_loss = total_loss / total_samples
            train_accuracy = total_correct / total_samples
            # accuracy = correct.sum().cpu().numpy() / answer.shape[0]

            # if avg_loss == 0:
            #     avg_loss = loss.item()
            #     train_accuracy = accuracy
            # else:
            #     avg_loss = 0.99 * avg_loss + 0.01 * loss.item()
            #     train_accuracy = 0.99 * train_accuracy + 0.01 * accuracy
            # self.total_epoch_loss += loss.item() * answer.size(0)

            dataset.set_description(
                'Epoch: {}; Avg Loss: {:.5f}; Avg Train Acc: {:.5f}'.format(
                    epoch + 1, avg_loss, train_accuracy))

        # NOTE(review): avg_loss/train_accuracy are bound inside the loop, so
        # an empty dataloader raises NameError here.
        self.total_epoch_loss = avg_loss

        # NOTE(review): local name shadows the builtin `dict`.
        dict = {
            "loss": avg_loss,
            "accuracy": train_accuracy,
            "avg_loss": avg_loss,  # For commet
            "avg_accuracy": train_accuracy,  # For commet
        }
        return dict

    def train(self):
        """Main loop: train, reduce LR, validate, log to Comet each epoch;
        optionally early-stop; finally save models and close the writer."""
        cfg = self.cfg
        print("Start Training")
        for epoch in range(self.start_epoch, self.max_epochs):

            with self.comet_exp.train():
                dict = self.train_epoch(epoch)
                self.reduce_lr()
                dict['epoch'] = epoch + 1
                dict['lr'] = self.lr
                self.comet_exp.log_metrics(
                    dict,
                    epoch=epoch + 1,
                )

            with self.comet_exp.validate():
                dict = self.log_results(epoch, dict)
                dict['epoch'] = epoch + 1
                dict['lr'] = self.lr
                self.comet_exp.log_metrics(
                    dict,
                    epoch=epoch + 1,
                )

            # NOTE(review): 'EALRY_STOPPING' typo is shared with the config
            # key, so renaming it requires a config-schema change.
            if cfg.TRAIN.EALRY_STOPPING:
                if epoch - cfg.TRAIN.PATIENCE == self.previous_best_epoch:
                    # if epoch - cfg.TRAIN.PATIENCE == self.previous_best_loss_epoch:
                    print('Early stop')
                    break

        self.comet_exp.log_asset(self.logfile)
        self.save_models(self.max_epochs)
        self.writer.close()
        print("Finished Training")
        print(
            f"Highest validation accuracy: {self.previous_best_acc} at epoch {self.previous_best_epoch}"
        )

    def log_results(self, epoch, dict, max_eval_samples=None):
        """Write train metrics to TensorBoard, run validation, track the best
        accuracy/loss epochs, snapshot periodically; returns the val metrics.
        Note: `epoch` is converted to 1-based internally."""
        epoch += 1
        self.writer.add_scalar("avg_loss", dict["loss"], epoch)
        self.writer.add_scalar("train_accuracy", dict["accuracy"], epoch)

        metrics = self.calc_accuracy("validation",
                                     max_samples=max_eval_samples)
        self.writer.add_scalar("val_accuracy_ema", metrics['acc_ema'], epoch)
        self.writer.add_scalar("val_accuracy", metrics['acc'], epoch)
        self.writer.add_scalar("val_loss_ema", metrics['loss_ema'], epoch)
        self.writer.add_scalar("val_loss", metrics['loss'], epoch)

        print(
            "Epoch: {epoch}\tVal Acc: {acc},\tVal Acc EMA: {acc_ema},\tAvg Loss: {loss},\tAvg Loss EMA: {loss_ema},\tLR: {lr}"
            .format(epoch=epoch, lr=self.lr, **metrics))

        if metrics['acc'] > self.previous_best_acc:
            self.previous_best_acc = metrics['acc']
            self.previous_best_epoch = epoch
        if metrics['loss'] < self.previous_best_loss:
            self.previous_best_loss = metrics['loss']
            self.previous_best_loss_epoch = epoch

        if epoch % self.snapshot_interval == 0:
            self.save_models(epoch)

        return metrics

    def calc_accuracy(self, mode="train", max_samples=None):
        """Evaluate both the online and EMA models on the chosen loader and
        return dict(acc, acc_ema, loss, loss_ema).

        NOTE(review): `max_samples` is accepted but never used here, and the
        averages are bound inside the loop — an empty loader raises NameError.
        """
        # set_mode only special-cases "train", so "validation" means eval mode.
        self.set_mode("validation")

        if mode == "train":
            loader = self.dataloader
        # elif (mode == "validation") or (mode == 'test'):
        #     loader = self.dataloader_val
        else:
            loader = self.dataloader_val

        total_correct = 0
        total_correct_ema = 0
        total_samples = 0
        total_loss = 0.
        total_loss_ema = 0.
        pbar = tqdm(loader, total=len(loader), desc=mode.upper(), ncols=20)
        for data in pbar:

            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)

            if self.cfg.CUDA:
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()

            with torch.no_grad():
                scores = self.model(image, question, question_len)
                scores_ema = self.model_ema(image, question, question_len)

                loss = self.loss_fn(scores, answer)
                loss_ema = self.loss_fn(scores_ema, answer)

            correct = scores.detach().argmax(1) == answer
            correct_ema = scores_ema.detach().argmax(1) == answer

            total_correct += correct.sum().cpu().item()
            total_correct_ema += correct_ema.sum().cpu().item()

            total_loss += loss.item() * answer.size(0)
            total_loss_ema += loss_ema.item() * answer.size(0)

            total_samples += answer.size(0)

            avg_acc = total_correct / total_samples
            avg_acc_ema = total_correct_ema / total_samples
            avg_loss = total_loss / total_samples
            avg_loss_ema = total_loss_ema / total_samples

            pbar.set_postfix({
                'Acc': f'{avg_acc:.5f}',
                'Acc Ema': f'{avg_acc_ema:.5f}',
                'Loss': f'{avg_loss:.5f}',
                'Loss Ema': f'{avg_loss_ema:.5f}',
            })

        return dict(acc=avg_acc,
                    acc_ema=avg_acc_ema,
                    loss=avg_loss,
                    loss_ema=avg_loss_ema)
示例#4
0
    print('Model loaded.')

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log_dict = helpers.flatten_dict(config)
    log_dict.update({'trainable_params': n_params})
    exp.log_parameters(log_dict)

    test_dataset = data.CSVDatasetsMerger(helpers.get_datasets_paths(config, 'test'))
    test_dataloader = DataLoader(test_dataset,
                       batch_size=config['evaluation']['eval_batch_size'],
                       shuffle=False,
                       drop_last=False,
                       num_workers=config['evaluation']['n_eval_workers'],
                       collate_fn=text_proc)

    evaluator = Evaluation(test_dataloader, config)

    print('Testing ...')
    results, assets, image_fns = evaluator.eval_model(model, finished_training=True)
    print('Finished testing. Uploading ...')

    exp.log_metrics(results, step=0, epoch=0)
    [exp.log_asset_data(asset, step=0) for asset in assets]
    [exp.log_image(fn, step=0) for fn in image_fns]

    print('Finished uploading.')




示例#5
0
def log_validation_proposals(experiment: Experiment,
                             proposals: Sequence[ImageAnnotation]):
    """Upload the validation proposals as a JSON asset on *experiment*."""
    serialized = [annotation.to_json() for annotation in proposals]
    experiment.log_asset_data(
        serialized,
        name="datatap/validation/proposals.json",
    )
示例#6
0
class Plotter:
    """
    Handles plotting and logging to comet.

    Args:
        exp_args (args.parse_args): arguments for the experiment
        agent_args (dict): arguments for the agent
        agent (Agent): the agent
    """
    def __init__(self, exp_args, agent_args, agent):
        self.exp_args = exp_args
        self.agent_args = agent_args
        self.agent = agent
        self.experiment = None
        if self.exp_args.plotting:
            # Create the Comet experiment and log both argument sets up front.
            self.experiment = Experiment(api_key=LOGGING_API_KEY,
                                         project_name=PROJECT_NAME,
                                         workspace=WORKSPACE)
            self.experiment.disable_mp()
            self.experiment.log_parameters(get_arg_dict(exp_args))
            self.experiment.log_parameters(flatten_arg_dict(agent_args))
            self.experiment.log_asset_data(json.dumps(get_arg_dict(exp_args)), name='exp_args')
            self.experiment.log_asset_data(json.dumps(agent_args), name='agent_args')
            if self.exp_args.checkpoint_exp_key is not None:
                self.load_checkpoint()
        self.result_dict = None
        # keep a hard-coded list of returns in case Comet fails
        self.returns = []

    def _plot_ts(self, key, observations, statistics, label, color):
        """
        Plot per-dimension time series: observed values as dots, overlaid with
        the distribution's mean/interval. At most 9 dimensions are plotted
        (one stacked subplot each).
        """
        dim_obs = min(observations.shape[1], 9)
        k = 1
        for i in range(dim_obs):
            # Equivalent to the int(str(dim_obs) + '1' + str(k)) form, since
            # dim_obs is capped at 9 above; the 3-arg call is explicit.
            plt.subplot(dim_obs, 1, k)
            observations_i = observations[:-1, i].cpu().numpy()
            if key == 'action' and self.agent.postprocess_action:
                observations_i = np.tanh(observations_i)
            plt.plot(observations_i.squeeze(), 'o', label='observation', color='k', markersize=2)
            if len(statistics) == 1:  # Bernoulli distribution
                probs = statistics['probs']
                probs = probs.cpu().numpy()
                plt.plot(probs, label=label, color=color)
            elif len(statistics) == 2:
                if 'loc' in statistics:
                    # Normal distribution
                    mean = statistics['loc']
                    std = statistics['scale']
                    mean = mean[:, i].cpu().numpy()
                    std = std[:, i].cpu().numpy()
                    mean = mean.squeeze()
                    std = std.squeeze()
                    x, plus, minus = mean, mean + std, mean - std
                    # Squash mean +/- std through tanh for Tanh-transformed
                    # action distributions (and postprocessed actions).
                    if key == 'action' and label == 'approx_post' and self.agent_args['approx_post_args']['dist_type'] in ['TanhNormal', 'TanhARNormal']:
                        # Tanh Normal distribution
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and label == 'direct_approx_post' and self.agent_args['approx_post_args']['dist_type'] in ['TanhNormal', 'TanhARNormal']:
                        # Tanh Normal distribution
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and label == 'prior' and self.agent_args['prior_args']['dist_type'] in ['TanhNormal', 'TanhARNormal']:
                        # Tanh Normal distribution
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and self.agent.postprocess_action:
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and label == 'prior' and self.agent_args['prior_args']['dist_type'] == 'NormalUniform':
                        # Normal + Uniform distribution
                        x, plus, minus = x, np.minimum(plus, 1.), np.maximum(minus, -1)
                elif 'low' in statistics:
                    # Uniform distribution
                    low = statistics['low'][:, i].cpu().numpy()
                    high = statistics['high'][:, i].cpu().numpy()
                    x = low + (high - low) / 2
                    # NOTE(review): plus/minus = x + high / x + low looks odd
                    # for a Uniform band (high/low alone would be the bounds)
                    # — confirm intended band before changing.
                    plus, minus = x + high, x + low
                else:
                    raise NotImplementedError
                plt.plot(x, label=label, color=color)
                plt.fill_between(np.arange(len(x)), plus, minus, color=color, alpha=0.2, label=label)
            else:
                # Fix: the original bare `NotImplementedError` expression was a
                # no-op; actually raise, matching the branch above.
                raise NotImplementedError
            k += 1

    def plot_states_and_rewards(self, states, rewards, step):
        """
        Plots the states and rewards for a collected episode.
        """
        # states
        plt.figure()
        dim_obs = states.shape[1]
        for i in range(dim_obs):
            plt.subplot(dim_obs, 1, i+1)
            states_i = states[:-1, i].cpu().numpy()
            plt.plot(states_i.squeeze(), 'o', label='state', color='k', markersize=2)
        self.experiment.log_figure(figure=plt, figure_name='states_ts_'+str(step))
        plt.close()

        # rewards
        plt.figure()
        rewards = rewards[:-1, 0].cpu().numpy()
        plt.plot(rewards.squeeze(), 'o', label='reward', color='k', markersize=2)
        self.experiment.log_figure(figure=plt, figure_name='rewards_ts_'+str(step))
        plt.close()

    def plot_episode(self, episode, step):
        """
        Plots a newly collected episode.
        """
        if self.exp_args.plotting:
            self.experiment.log_metric('cumulative_reward', episode['reward'].sum(), step)

            def merge_legends():
                # De-duplicate legend entries (fill_between and plot share labels).
                handles, labels = plt.gca().get_legend_handles_labels()
                newLabels, newHandles = [], []
                for handle, label in zip(handles, labels):
                    if label not in newLabels:
                        newLabels.append(label)
                        newHandles.append(handle)

                plt.legend(newHandles, newLabels)

            for k in episode['distributions'].keys():
                for i, l in enumerate(episode['distributions'][k].keys()):
                    color = COLORS[i]
                    self._plot_ts(k, episode[k], episode['distributions'][k][l], l, color)
                plt.suptitle(k)
                merge_legends()
                self.experiment.log_figure(figure=plt, figure_name=k + '_ts_'+str(step))
                plt.close()

            self.plot_states_and_rewards(episode['state'], episode['reward'], step)

    def log_eval(self, episode, eval_states, step):
        """
        Plots an evaluation episode performance. Logs the episode.

        Args:
            episode (dict): dictionary containing agent's collected episode
            eval_states (dict): dictionary of MuJoCo simulator states
            step (int): the current step number in training
        """
        # plot and log eval returns
        eval_return = episode['reward'].sum()
        print(' Eval. Return at Step ' + str(step) + ': ' + str(eval_return.item()))
        self.returns.append(eval_return.item())
        if self.exp_args.plotting:
            self.experiment.log_metric('eval_cumulative_reward', eval_return, step)
            json_str = json.dumps(self.returns)
            self.experiment.log_asset_data(json_str, name='eval_returns', overwrite=True)

            # log the episode itself
            for ep_item_str in ['state', 'action', 'reward']:
                ep_item = episode[ep_item_str].tolist()
                json_str = json.dumps(ep_item)
                item_name = 'episode_step_' + str(step) + '_' + ep_item_str
                self.experiment.log_asset_data(json_str, name=item_name)

            # log the MuJoCo simulator states
            for sim_item_str in ['qpos', 'qvel']:
                if len(eval_states[sim_item_str]) > 0:
                    sim_item = eval_states[sim_item_str].tolist()
                    json_str = json.dumps(sim_item)
                    item_name = 'episode_step_' + str(step) + '_' + sim_item_str
                    self.experiment.log_asset_data(json_str, name=item_name)

    def plot_agent_kl(self, agent_kl, step):
        """Log the agent KL metric to Comet (no-op when plotting is off)."""
        if self.exp_args.plotting:
            self.experiment.log_metric('agent_kl', agent_kl, step)

    def log_results(self, results):
        """
        Log the results dictionary.
        """
        if self.result_dict is None:
            self.result_dict = {}
        for k, v in flatten_arg_dict(results).items():
            if k not in self.result_dict:
                self.result_dict[k] = [v]
            else:
                self.result_dict[k].append(v)

    def plot_results(self, timestep):
        """
        Plot/log the results to Comet.

        NOTE(review): assumes log_results() was called at least once since the
        last flush — result_dict is None otherwise and .items() would fail.
        """
        if self.exp_args.plotting:
            for k, v in self.result_dict.items():
                avg_value = np.mean(v)
                self.experiment.log_metric(k, avg_value, timestep)
        self.result_dict = None

    def plot_model_eval(self, episode, predictions, log_likelihoods, step):
        """
        Plot/log the results from model evaluation.

        Args:
            episode (dict): a collected episode
            predictions (dict): predictions from each state, containing [n_steps, horizon, n_dims]
            log_likelihoods (dict): log-likelihood evaluations of predictions, containing [n_steps, horizon, 1]
        """
        if self.exp_args.plotting:
            for variable, lls in log_likelihoods.items():
                # average the log-likelihood estimates and plot the result at the horizon length
                mean_ll = lls[:, -1].mean().item()
                self.experiment.log_metric(variable + '_pred_log_likelihood', mean_ll, step)
                # plot log-likelihoods as a function of rollout step
                plt.figure()
                mean = lls.mean(dim=0).view(-1)
                std = lls.std(dim=0).view(-1)
                plt.plot(mean.numpy())
                lower = mean - std
                upper = mean + std
                plt.fill_between(np.arange(lls.shape[1]), lower.numpy(), upper.numpy(), alpha=0.2)
                plt.xlabel('Rollout Step')
                plt.ylabel('Prediction Log-Likelihood')
                plt.xticks(np.arange(lls.shape[1]))
                self.experiment.log_figure(figure=plt, figure_name=variable + '_pred_ll_' + str(step))
                plt.close()

            # plot predictions vs. actual values for an arbitrary time step
            time_step = np.random.randint(predictions['state']['loc'].shape[0])
            for variable, preds in predictions.items():
                pred_loc = preds['loc'][time_step]
                pred_scale = preds['scale'][time_step]
                x = episode[variable][time_step+1:time_step+1+pred_loc.shape[0]]
                plt.figure()
                horizon, n_dims = pred_loc.shape
                for plot_num in range(n_dims):
                    plt.subplot(n_dims, 1, plot_num + 1)
                    plt.plot(pred_loc[:, plot_num].numpy())
                    lower = pred_loc[:, plot_num] - pred_scale[:, plot_num]
                    upper = pred_loc[:, plot_num] + pred_scale[:, plot_num]
                    plt.fill_between(np.arange(horizon), lower.numpy(), upper.numpy(), alpha=0.2)
                    plt.plot(x[:, plot_num].numpy(), '.')
                plt.xlabel('Rollout Step')
                plt.xticks(np.arange(horizon))
                self.experiment.log_figure(figure=plt, figure_name=variable + '_pred_' + str(step))
                plt.close()

    def save_checkpoint(self, step):
        """
        Checkpoint the model by getting the state dictionary for each component.
        """
        if self.exp_args.plotting:
            print('Checkpointing the agent...')
            state_dict = self.agent.state_dict()
            # Move tensors to CPU so the checkpoint loads on any device.
            cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
            ckpt_path = os.path.join('./ckpt_step_'+ str(step) + '.ckpt')
            torch.save(cpu_state_dict, ckpt_path)
            self.experiment.log_asset(ckpt_path)
            # The asset lives on Comet; remove the local temp file.
            os.remove(ckpt_path)
            print('Done.')

    def load_checkpoint(self, timestep=None):
        """
        Loads a checkpoint from Comet.

        Args:
            timestep (int, optional): the checkpoint timestep, default is latest
        """
        load_checkpoint(self.agent, self.exp_args.checkpoint_exp_key, timestep)
示例#7
0
class Trainer:
    def __init__(self, cfg, log_dir):
        """
        Set up a GAN training run: models, optimizers, logging, and data.

        Args:
            cfg: configuration object (attribute-style access,
                JSON-serializable) holding model/train/comet settings.
            log_dir: directory for checkpoints, logs, and TF summaries;
                a falsy value disables all on-disk logging.
        """

        self.log_dir = log_dir
        self.cfg = cfg

        # Local logging is enabled only when a log directory is supplied.
        if log_dir:
            self.log = True
            self.model_dir = osp.join(log_dir, "checkpoints")
            mkdir_p(self.model_dir)
            self.logfile = osp.join(log_dir, "logfile.log")
            # Redirect stdout into the logfile for the rest of the process.
            sys.stdout = Logger(logfile=self.logfile)
            self.summary_writer = tf.summary.create_file_writer(log_dir)
        else:
            self.log = False

        self.generator = Generator(**cfg.model.generator)
        self.discriminator = Discriminator(**cfg.model.discriminator)

        self.g_optimizer = optimizers.Adam(**cfg.train.generator.optimizer)
        self.d_optimizer = optimizers.Adam(**cfg.train.discriminator.optimizer)

        # The discriminator emits raw logits, hence from_logits=True.
        self.bce = losses.BinaryCrossentropy(from_logits=True)

        # TODO resume model

        # Comet experiment is disabled unless logcomet AND local logging are on.
        self.comet = Experiment(
            api_key=os.getenv("COMET_API_KEY"),
            workspace=os.getenv("COMET_WORKSPACE"),
            project_name=cfg.comet_project_name,
            disabled=cfg.logcomet is False or not self.log,
        )
        self.comet.set_name(f"{cfg.experiment_name}/{cfg.run_name}")
        self.comet.log_parameters(flatten_json_iterative_solution(self.cfg))
        self.comet.log_asset_data(json.dumps(self.cfg, indent=4), file_name="cfg.json")
        if cfg.cfg_file:
            self.comet.log_asset(cfg.cfg_file)

        # Epoch/step counters live inside the checkpoint so a resume restores
        # training progress along with the model and optimizer state.
        self.start_epoch = tf.Variable(0)
        self.curr_step = tf.Variable(0)
        self.ckpt = tf.train.Checkpoint(
            generator=self.generator,
            discriminator=self.discriminator,
            g_optimizer=self.g_optimizer,
            d_optimizer=self.d_optimizer,
            start_epoch=self.start_epoch,
            curr_step=self.curr_step,
        )
        if cfg.train.resume:
            # cfg.train.resume doubles as the directory to resume from.
            ckpt_resumer = tf.train.CheckpointManager(
                self.ckpt, cfg.train.resume, max_to_keep=3,
            )
            # if a checkpoint exists, restore the latest checkpoint.
            if ckpt_resumer.latest_checkpoint:
                self.ckpt.restore(ckpt_resumer.latest_checkpoint)
                print("Latest checkpoint restored!!", ckpt_resumer.latest_checkpoint)
                print(
                    f"Last epoch trained:{self.start_epoch.numpy()}, Current step: {self.curr_step.numpy()}"
                )
        if self.log:
            # Persist the config next to the run and manage rolling checkpoints.
            with open(osp.join(self.log_dir, "cfg.json"), "w") as f:
                json.dump(cfg, f, indent=4)
            self.ckpt_manager = tf.train.CheckpointManager(
                self.ckpt, self.model_dir, max_to_keep=3
            )

        self.prepare_dataset(self.cfg.train.data_dir)
        self.print_info()

        # Fixed noise/views support overfitting a single sample for debugging;
        # otherwise train_epoch draws fresh noise every batch (z_* left None).
        if self.cfg.train.generator.fixed_z:
            self.z_bg = sample_z(1, self.generator.z_dim_bg, num_objects=1)
            self.z_fg = sample_z(1, self.generator.z_dim_fg, num_objects=1)
            self.bg_view = sample_view(1, num_objects=1)
            self.fg_view = sample_view(1, num_objects=1)
        else:
            self.z_bg = self.z_fg = self.bg_view = self.fg_view = None

    def prepare_dataset(self, data_dir):
        """
        Index the training images in *data_dir* and build the input pipeline.

        Sets ``num_tr`` (count of ``*.png`` files, non-recursive),
        ``list_ds_train`` (dataset of file paths), ``labeled_ds`` (decoded and
        resized images via ``process_path``), and ``steps_per_epoch``.

        Args:
            data_dir: directory scanned (non-recursively) for ``*.png`` files.
        """
        self.data_dir = data_dir
        # Build the glob pattern once; it was previously duplicated with an
        # inconsistent os.path.join / osp.join mix.
        pattern = osp.join(self.data_dir, "*.png")
        self.num_tr = len(glob.glob(pattern))
        self.list_ds_train = tf.data.Dataset.list_files(pattern)
        self.labeled_ds = self.list_ds_train.map(
            lambda x: process_path(
                x, self.cfg.train.image_height, self.cfg.train.image_width
            ),
            num_parallel_calls=AUTOTUNE,
        )
        self.steps_per_epoch = int(math.ceil(self.num_tr / self.cfg.train.batch_size))

    def print_info(self):
        print("Using config:")
        pprint.pprint(self.cfg)
        print("\n")
        pprint.pprint("Size of train dataset: {}".format(self.num_tr))
        print("\n")

    # losses
    def discriminator_loss(self, real, generated):
        """
        Standard GAN discriminator loss on raw logits.

        Real samples are trained toward label 1 and generated samples toward
        label 0; the two BCE terms are averaged.

        Args:
            real: discriminator logits for real images.
            generated: discriminator logits for generated images.

        Returns:
            Scalar loss tensor.
        """
        loss_on_real = self.bce(tf.ones_like(real), real)
        loss_on_fake = self.bce(tf.zeros_like(generated), generated)
        return 0.5 * (loss_on_real + loss_on_fake)

    def generator_loss(self, generated):
        """
        Non-saturating GAN generator loss: push fake logits toward label 1.

        Args:
            generated: discriminator logits for generated images.

        Returns:
            Scalar loss tensor.
        """
        target = tf.ones_like(generated)
        return self.bce(target, generated)

    # def generate_random_noise(self, batch_size, num_objects=(3, 10)):
    #     z_bg = tf.random.uniform(
    #         (batch_size, self.generator.z_dim_bg), minval=-1, maxval=1
    #     )
    #     num_objs = tf.random.uniform(
    #         (batch_size,),
    #         minval=num_objects[0],
    #         maxval=num_objects[1] + 1,
    #         dtype=tf.int32,
    #     )
    #     tensors = []
    #     max_len = max(num_objs)
    #     for no in num_objs:
    #         _t = tf.random.uniform((no, self.generator.z_dim_fg), minval=-1, maxval=1)
    #         _z = tf.zeros((max_len - no, self.generator.z_dim_fg), dtype=tf.float32)
    #         _t = tf.concat((_t, _z), axis=0)
    #         tensors.append(_t)
    #     z_fg = tf.stack(tensors, axis=0)

    #     return z_bg, z_fg

    def batch_logits(self, image_batch, z_bg, z_fg, bg_view, fg_view):
        """
        Run one generator/discriminator forward pass for a batch.

        Args:
            image_batch: batch of real images (presumably in [0, 1] — the
                rescale below maps them to [-1, 1]; confirm against the
                data pipeline).
            z_bg, z_fg: background/foreground latent codes.
            bg_view, fg_view: background/foreground camera views.

        Returns:
            Tuple of (fake logits, real logits, generated images).
        """
        fake_images = self.generator(z_bg, z_fg, bg_view, fg_view)
        fake_logits = self.discriminator(fake_images, training=True)

        # Rescale real images to [-1, 1] before the discriminator sees them.
        real_images = image_batch * 2 - 1
        # Instance noise on real images: always when configured, and during
        # the first 2000 steps regardless.
        if self.cfg.train.discriminator.random_noise or self.curr_step <= 2000:
            real_images = real_images + tf.random.normal(real_images.shape, stddev=0.01)
        real_logits = self.discriminator(real_images, training=True,)

        return fake_logits, real_logits, fake_images

    # @tf.function
    def train_epoch(self, epoch):
        """
        Run one epoch of adversarial training.

        For every batch: sample (or reuse fixed) latent codes and views,
        compute both losses under a single persistent tape, and apply one
        optimizer step each to the discriminator and the generator. Running
        averages are shown on the progress bar and flushed to the loggers
        every ``cfg.train.it_log_interval`` iterations.

        Args:
            epoch: zero-based epoch index; also drives the curriculum on the
                number of foreground objects sampled per scene.
        """
        train_iter = prepare_for_training(
            self.labeled_ds, self.cfg.train.batch_size, cache=False,
        )
        pbar = tqdm(
            enumerate(train_iter),
            total=self.steps_per_epoch,
            ncols=20,
            desc=f"Epoch {epoch}",
            mininterval=10,
            miniters=50,
        )
        # Running sums since the last log flush; `counter` is the number of
        # batches accumulated in them.
        total_d_loss = 0.0
        total_g_loss = 0.0
        counter = 1
        # Discriminator "accuracy" tallies since the last log flush.
        real_are_real_samples_counter = 0
        real_samples_counter = 0
        fake_are_fake_samples_counter = 0
        fake_samples_counter = 0

        z_bg = z_fg = None
        for it, image_batch in pbar:
            bsz = image_batch.shape[0]
            # generated random noise
            if self.z_bg is not None:
                # For overfitting one sample and debugging
                z_bg = tf.repeat(self.z_bg, bsz, axis=0)
                z_fg = tf.repeat(self.z_fg, bsz, axis=0)
                bg_view = self.bg_view
                fg_view = self.fg_view
            else:
                z_bg = sample_z(bsz, self.generator.z_dim_bg, num_objects=1)
                # Curriculum: max object count grows from 3 toward 10, one
                # extra object every two epochs.
                z_fg = sample_z(
                    bsz,
                    self.generator.z_dim_fg,
                    num_objects=(3, min(10, 3 + 1 * (epoch // 2))),
                )
                bg_view = sample_view(
                    batch_size=bsz,
                    num_objects=1,
                    azimuth_range=(-20, 20),
                    elevation_range=(-10, 10),
                    scale_range=(0.9, 1.1),
                )
                fg_view = sample_view(batch_size=bsz, num_objects=z_fg.shape[1])

            # Persistent tape: both losses share one forward pass, and two
            # separate tape.gradient calls are made below.
            with tf.GradientTape(persistent=True) as tape:
                # fake img
                d_fake_logits, d_real_logits, generated = self.batch_logits(
                    image_batch, z_bg, z_fg, bg_view, fg_view
                )
                d_loss = self.discriminator_loss(d_real_logits, d_fake_logits)
                g_loss = self.generator_loss(d_fake_logits)

            total_d_loss += d_loss.numpy()
            # total_g_loss += g_loss.numpy() / self.cfg.train.generator.update_freq
            total_g_loss += g_loss.numpy()

            d_variables = self.discriminator.trainable_variables
            d_gradients = tape.gradient(d_loss, d_variables)
            self.d_optimizer.apply_gradients(zip(d_gradients, d_variables))

            g_variables = self.generator.trainable_variables
            g_gradients = tape.gradient(g_loss, g_variables)
            self.g_optimizer.apply_gradients(zip(g_gradients, g_variables))

            # Persistent tapes hold resources until explicitly released.
            del tape

            real_samples_counter += d_real_logits.shape[0]
            fake_samples_counter += d_fake_logits.shape[0]

            # Logit >= 0 means the discriminator classifies the sample as real.
            real_are_real = (d_real_logits >= 0).numpy().sum()
            real_are_real_samples_counter += real_are_real

            fake_are_fake = (d_fake_logits < 0).numpy().sum()
            fake_are_fake_samples_counter += fake_are_fake

            # according to paper generator makes 2 steps per each step of the disc
            # for _ in range(self.cfg.train.generator.update_freq - 1):
            #     with tf.GradientTape(persistent=True) as tape:
            #         # fake img
            #         d_fake_logits, _, generated = self.batch_logits(
            #             image_batch, z_bg, z_fg
            #         )
            #         g_loss = self.generator_loss(d_fake_logits)
            #     g_variables = self.generator.trainable_variables
            #     g_gradients = tape.gradient(g_loss, g_variables)
            #     self.g_optimizer.apply_gradients(zip(g_gradients, g_variables))
            #     total_g_loss += g_loss.numpy() / self.cfg.train.generator.update_freq

            # Postfix shows current-batch value and (running average) plus the
            # fraction of real/fake samples the discriminator got right.
            pbar.set_postfix(
                g_loss=f"{g_loss.numpy():.2f} ({total_g_loss / (counter):.2f})",
                d_loss=f"{d_loss.numpy():.4f} ({total_d_loss / (counter):.4f})",
                rrr=f"{real_are_real / d_real_logits.shape[0]:.1f} ({real_are_real_samples_counter / real_samples_counter:.1f})",
                frf=f"{fake_are_fake / d_fake_logits.shape[0]:.1f} ({fake_are_fake_samples_counter / fake_samples_counter:.1f})",
                refresh=False,
            )

            if it % (self.cfg.train.it_log_interval) == 0:
                self.log_training(
                    d_loss=total_d_loss / counter,
                    g_loss=total_g_loss / counter,
                    real_are_real=real_are_real_samples_counter / real_samples_counter,
                    fake_are_fake=fake_are_fake_samples_counter / fake_samples_counter,
                    # Generator output is in [-1, 1]; map back to [0, 1] for images.
                    fake_images=(generated + 1) / 2,
                    real_images=image_batch,
                    d_fake_logits=d_fake_logits,
                    d_real_logits=d_real_logits,
                    epoch=epoch,
                    it=it,
                )
                # Reset the running sums; counter becomes 1 again via the
                # increment below before the next accumulation.
                real_are_real_samples_counter = 0
                fake_are_fake_samples_counter = 0
                real_samples_counter = 0
                fake_samples_counter = 0
                total_d_loss = 0.0
                total_g_loss = 0.0
                counter = 0

            counter += 1
            gc.collect()

        del train_iter

    def log_training(
        self,
        d_loss,
        g_loss,
        fake_images,
        real_images,
        d_fake_logits,
        d_real_logits,
        epoch,
        it,
        real_are_real,
        fake_are_fake,
    ):
        """
        Write one logging snapshot to TensorBoard and Comet.

        No-op when local logging is disabled (``self.log`` is False).

        Args:
            d_loss, g_loss: averaged losses since the previous flush.
            fake_images: generated images in [0, 1].
            real_images: real input batch.
            d_fake_logits, d_real_logits: discriminator logits for the batch.
            epoch: current epoch index (used in image tags).
            it: iteration within the epoch; added to ``self.curr_step`` to
                get the global step.
            real_are_real, fake_are_fake: discriminator accuracy fractions.
        """
        if self.log:
            # Global step = steps completed in previous epochs + this iteration.
            curr_step = (self.curr_step + it).numpy()
            # Partition each batch by the discriminator's real/fake decision.
            real_are_real_images, real_are_fake_images = split_images_on_disc(
                real_images, d_real_logits
            )
            fake_are_real_images, fake_are_fake_images = split_images_on_disc(
                fake_images, d_fake_logits
            )
            with self.summary_writer.as_default():
                tf.summary.scalar(
                    "losses/d_loss",
                    d_loss,
                    step=curr_step,
                    description="Average of predicting real images as real and fake as fake",
                )
                tf.summary.scalar(
                    "losses/g_loss",
                    g_loss,
                    step=curr_step,
                    description="Predicting fake images as real",
                )
                tf.summary.scalar(
                    "accuracy/real",
                    real_are_real,
                    step=curr_step,
                    description="Real images classified as real",
                )
                tf.summary.scalar(
                    "accuracy/fake",
                    fake_are_fake,
                    step=curr_step,
                    description="Fake images classified as fake",
                )
                tf.summary.image(
                    f"{epoch}-{curr_step}-fake/are_fake",
                    fake_are_fake_images,
                    max_outputs=25,
                    step=curr_step,
                    description="Fake images that the discriminator says are fake",
                )
                tf.summary.image(
                    f"{epoch}-{curr_step}-fake/are_real",
                    fake_are_real_images,
                    max_outputs=25,
                    step=curr_step,
                    description="Fake images that the discriminator says are real",
                )
                tf.summary.image(
                    f"{epoch}-{curr_step}-real/are_fake",
                    real_are_fake_images,
                    max_outputs=25,
                    step=curr_step,
                    description="Real images that the discriminator says are fake",
                )
                tf.summary.image(
                    f"{epoch}-{curr_step}-real/are_real",
                    real_are_real_images,
                    max_outputs=25,
                    step=curr_step,
                    description="Real images that the discriminator says are real",
                )

            # Mirror scalars and labeled image grids to Comet.
            self.comet.log_metrics(
                {"d_loss": d_loss, "g_loss": g_loss}, step=curr_step, epoch=epoch
            )
            fig = show_batch(fake_images, labels=disc_preds_to_label(d_fake_logits))
            self.comet.log_figure(
                figure=fig, figure_name="" f"fake_{epoch}_{it}.jpg", step=curr_step,
            )
            plt.close(fig)
            fig = show_batch(real_images, labels=disc_preds_to_label(d_real_logits))
            self.comet.log_figure(
                figure=fig, figure_name="" f"real_{epoch}_{it}.jpg", step=curr_step,
            )
            plt.close(fig)

    def train(self):
        """
        Drive the full training run.

        Resumes from ``self.start_epoch`` (restored from a checkpoint when
        resuming), advances the epoch/step counters after each epoch, and
        saves a rolling checkpoint every ``cfg.train.snapshot_interval``
        epochs when local logging is enabled.
        """
        print("Start training")
        first_epoch = self.start_epoch.numpy()
        for epoch in range(first_epoch, self.cfg.train.epochs):
            with self.comet.train():
                self.train_epoch(epoch)
            # Keep progress counters in sync for checkpoint resume.
            self.curr_step.assign_add(self.steps_per_epoch)
            self.start_epoch.assign_add(1)
            completed = epoch + 1
            if self.log and completed % self.cfg.train.snapshot_interval == 0:
                self.ckpt_manager.save(completed)

    def save_model(self, epoch):
        """
        Export the trained model for *epoch*.

        Currently a no-op placeholder; rolling checkpoints are handled by
        ``self.ckpt_manager`` in ``train``. TODO: implement model export.
        """
        pass