예제 #1
0
    def _write_traces(self, split_name, train_steps, edit_traces, loss_traces):
        """Write paired edit/loss traces to ``<workspace.traces>/<split_name>/<train_steps>.txt``.

        For every (edit trace, loss trace) pair, the file receives the text of
        the edit trace, a newline, the text of the loss trace, and a blank line.
        Pairing uses ``zip``, so output stops at the shorter of the two
        sequences. The file is opened with ``codecs.open`` as UTF-8.
        """
        out_dir = join(self.workspace.traces, split_name)
        makedirs(out_dir)
        out_file = join(out_dir, '{}.txt'.format(train_steps))

        with codecs.open(out_file, 'w', encoding='utf-8') as handle:
            for pair in zip(edit_traces, loss_traces):
                edit, loss = pair
                # Python 2 ``unicode`` so codecs can encode the text as UTF-8.
                handle.write(unicode(edit))
                handle.write('\n')
                handle.write(unicode(loss))
                handle.write('\n\n')
예제 #2
0
파일: plot.py 프로젝트: lvyiwei1/StylePTB
def show(title, directory=''):
    """If in IPython, show, otherwise, save to file.

    Outside IPython the current matplotlib figure is saved as
    ``os.path.join(directory, title) + '.png'`` and all figures are then
    closed to free memory.

    Args:
        title: file name (without the ``.png`` extension) for the saved figure.
        directory: directory to save into, created via ``makedirs`` if needed.
            The default ``''`` means the current working directory.
    """
    import matplotlib.pyplot as plt
    if in_ipython():
        plt.show()
        return

    # '' means the current working directory, which already exists. Passing it
    # to a makedirs built on os.makedirs would fail (os.makedirs('') raises),
    # so only create a directory when one was actually named.
    if directory:
        makedirs(directory)

    plt.savefig(os.path.join(directory, title) + '.png')
    # close all figures to conserve memory
    plt.close('all')
예제 #3
0
    def save(self, path):
        """Checkpoint the model, optimizer and bookkeeping metadata into ``path``.

        Ensures ``path`` exists via ``makedirs``, then:

        * replaces ``self.random_state`` with a fresh ``RandomState()``
          (presumably a snapshot of the current RNG state -- confirm against
          ``RandomState``);
        * writes ``self.model.state_dict()`` to ``path/model`` and
          ``self.optimizer.state_dict()`` to ``path/optimizer`` with
          ``torch.save``;
        * pickles ``train_steps``, ``random_state`` and ``max_grad_norm`` as a
          dict into ``path/metadata.pkl``.
        """
        makedirs(path)

        # Refresh the stored random state so the pickled metadata includes it.
        self.random_state = RandomState()

        for filename, component in (('model', self.model), ('optimizer', self.optimizer)):
            torch.save(component.state_dict(), join(path, filename))

        metadata_attrs = ('train_steps', 'random_state', 'max_grad_norm')
        metadata = dict((name, getattr(self, name)) for name in metadata_attrs)
        with open(join(path, 'metadata.pkl'), 'wb') as out:
            pickle.dump(metadata, out)
예제 #4
0
파일: log.py 프로젝트: SAGNIKMJR/wge
    def _log_traces(self, episodes, label, train_step):
        """Dump traces for ``episodes`` under ``join(self.trace_dir, label)``.

        With ``base = join(self.trace_dir, label, str(train_step))`` this writes:

        * ``base.json``: a JSON list holding each trace's ``to_json_dict()``.
        * ``base.txt``: each trace's ``dumps()`` output under an
          ``EPISODE i`` banner.
        * ``base-img/``: only for episodes where ``self._has_screenshot(ep)``
          is true -- one PNG per experience (``<train_step>-<i>-<j>.png``) and
          a JSON list of that episode's ``action.to_dict()`` values
          (``<train_step>-<i>.json``).

        Args:
            episodes: episodes to log. It is iterated twice (once to build the
                traces, once for screenshots), so it must be re-iterable. Each
                episode is iterated as experiences exposing ``.state`` and
                ``.action``.
            label: sub-directory of ``self.trace_dir`` to write into.
            train_step: training step used to name the output files.
        """
        import codecs

        trace_dir = join(self.trace_dir, label)
        makedirs(trace_dir)
        trace_path = join(trace_dir, str(train_step))

        episode_traces = [MiniWoBEpisodeTrace(ep) for ep in episodes]

        # save machine-readable version. Serialize the list itself: dumping
        # str(trace_dicts) produced one JSON string holding a Python repr,
        # which is not machine-readable. default=str is a fallback so values
        # json cannot encode natively are stringified instead of aborting.
        # codecs.open takes the encoding as its third positional argument
        # (the builtin open's third argument is an integer buffering size).
        trace_dicts = [trace.to_json_dict() for trace in episode_traces]
        with codecs.open(trace_path + '.json', 'w', 'utf8') as f:
            json.dump(trace_dicts, f, indent=2, default=str)

        # save pretty-printed version
        with codecs.open(trace_path + '.txt', 'w', 'utf8') as f:
            for i, trace in enumerate(episode_traces):
                f.write("=" * 25 + " EPISODE {} ".format(i) + "=" * 25)
                f.write('\n\n')
                f.write(trace.dumps())
                f.write('\n\n')

        # save screenshots; the image directory is only created when at
        # least one episode actually has screenshots
        img_path = trace_path + '-img'
        for i, ep in enumerate(episodes):
            if not self._has_screenshot(ep):
                continue
            makedirs(img_path)
            actions = []
            for j, experience in enumerate(ep):
                path = join(img_path, '{}-{}-{}.png'.format(train_step, i, j))
                experience.state.screenshot.save(path)
                actions.append(experience.action.to_dict())
            # write action summary
            path = join(img_path, '{}-{}.json'.format(train_step, i))
            with open(path, 'w') as fout:
                json.dump(actions, fout)
예제 #5
0
 def __init__(self, absolute_path, sync=True):
     """Wrap the directory at ``absolute_path`` and ensure it exists via ``makedirs``.

     Args:
         absolute_path: path of the directory. Its last component, after
             ``os.path.normpath`` strips any trailing separator, becomes
             ``self._name``.
         sync: stored unchanged as ``self._sync``; its meaning is defined by
             code outside this view.
     """
     normalized = os.path.normpath(absolute_path)
     self._name = os.path.basename(normalized)
     self._absolute_path = absolute_path
     self._sync = sync
     makedirs(absolute_path)
     # No sub-directories are tracked yet.
     self._subdirs = []
예제 #6
0
 def clear_cache(self):
     """Delete ``self.cache_dir`` and everything in it, then recreate it.

     A cache directory that is already missing counts as cleared; any other
     error from ``shutil.rmtree`` (e.g. permissions) still propagates.
     """
     import errno
     from gtd.io import makedirs
     try:
         shutil.rmtree(self.cache_dir)
     except OSError as e:
         # Nothing to delete is fine -- the directory is recreated below.
         if e.errno != errno.ENOENT:
             raise
     makedirs(self.cache_dir)
예제 #7
0
 def tensorboard_dir(self):
     """Return the TensorBoard directory derived from ``self.save_dir``.

     The path comes from ``self.get_tensorboard_dir`` and is passed to
     ``makedirs`` before being returned, so callers can write to it directly.
     """
     target = self.get_tensorboard_dir(self.save_dir)
     makedirs(target)
     return target
예제 #8
0
 def checkpoint_dir(self):
     """Return the checkpoint directory derived from ``self.save_dir``.

     The path comes from ``self.get_checkpoint_dir`` and is passed to
     ``makedirs`` before being returned, so callers can write to it directly.
     """
     ckpt_path = self.get_checkpoint_dir(self.save_dir)
     makedirs(ckpt_path)
     return ckpt_path
 def tensorboard_dir(self):
     """Return ``self.get_tensorboard_dir(self.save_dir)``, ensuring it exists first."""
     tb_path = self.get_tensorboard_dir(self.save_dir)
     makedirs(tb_path)  # make sure the directory is there before handing it out
     return tb_path
 def checkpoint_dir(self):
     """Return ``self.get_checkpoint_dir(self.save_dir)``, ensuring it exists first."""
     directory = self.get_checkpoint_dir(self.save_dir)
     makedirs(directory)  # make sure the directory is there before handing it out
     return directory
 def clear_cache(self):
     """Delete ``self.cache_dir`` and everything in it, then recreate it.

     A cache directory that is already missing counts as cleared; any other
     error from ``shutil.rmtree`` (e.g. permissions) still propagates.
     """
     import errno
     try:
         shutil.rmtree(self.cache_dir)
     except OSError as e:
         # Nothing to delete is fine -- the directory is recreated below.
         if e.errno != errno.ENOENT:
             raise
     makedirs(self.cache_dir)
 def __init__(self, fxn, cache_dir, serialize, deserialize):
     """Set up file-backed memoization of ``fxn``.

     Args:
         fxn: the wrapped function, forwarded to the parent class initializer.
         cache_dir: cache directory, stored as ``self.cache_dir`` and created
             via ``makedirs``.
         serialize: stored as ``self.serialize`` (presumably used to write
             cached results -- confirm in the rest of the class).
         deserialize: stored as ``self.deserialize`` (presumably used to read
             cached results back -- confirm in the rest of the class).
     """
     super(FileMemoized, self).__init__(fxn)
     self.cache_dir, self.serialize, self.deserialize = cache_dir, serialize, deserialize
     makedirs(cache_dir)
예제 #13
0
 def __init__(self, absolute_path, sync=True):
     """Track the directory at ``absolute_path``, ensuring it exists via ``makedirs``.

     Args:
         absolute_path: directory path; ``self._name`` is its final component
             (trailing separators are removed by ``os.path.normpath`` first).
         sync: kept as ``self._sync``; interpreted by code not visible here.
     """
     self._name = os.path.basename(os.path.normpath(absolute_path))
     self._absolute_path, self._sync = absolute_path, sync
     makedirs(absolute_path)
     self._subdirs = []  # starts with no known sub-directories