def run(args, myargs):
    """Assemble a ``SearchConfig`` from ``args`` and ``myargs`` and call ``main``.

    The config is filled in two passes: first every item of ``args`` is copied
    onto a fresh ``SearchConfig`` (a name that already exists on the config
    trips an assertion), then every item of the ``myargs.config`` section
    selected by ``args.command`` is applied on top, printing a notice for
    names the config did not previously have.

    Args:
        args: Object providing ``items()``, ``command`` and ``outdir``.
        myargs: Object providing ``config``, ``writer`` and ``logger``.

    Raises:
        AssertionError: If a key of ``args`` is already an attribute of the
            freshly created config.
    """
    command_settings = getattr(myargs.config, args.command)
    config = SearchConfig()

    # Entries coming from ``args`` must be new to the config.
    for name, value in args.items():
        assert not hasattr(config, name)
        setattr(config, name, value)

    # Command-specific settings are always applied; unknown names are only
    # reported, not rejected.
    for name, value in command_settings.items():
        if not hasattr(config, name):
            print('* config does not have %s' % name)
        setattr(config, name, value)

    device = torch.device("cuda")

    # Record the final configuration in the summary writer and in the log.
    myargs.writer.add_text('all_config', config.as_markdown(), 0)
    logger = myargs.logger
    config.print_params(logger.info_msg)

    # All output locations are derived from the run's output directory.
    config.data_path = os.path.expanduser(config.data_path)
    output_dir = args.outdir
    config.plot_path = os.path.join(output_dir, 'plot')
    config.path = output_dir

    main(config=config, logger=logger, device=device, myargs=myargs)
def get_current_node_index(): if "PAI_CURRENT_TASK_ROLE_CURRENT_TASK_INDEX" not in os.environ: return 0 return int(os.environ["PAI_CURRENT_TASK_ROLE_CURRENT_TASK_INDEX"]) if __name__ == "__main__": config = SearchConfig() if config.nni: if config.nni == "gt_mock": nni_tools.mock_result() else: config.designated_subgraph = [nni_tools.get_param()] config.path = nni_tools.get_output_dir() # tensorboard writer = SummaryWriter(log_dir=os.path.join(config.path, "tb")) writer.add_text('config', config.as_markdown(), 0) logger = utils.get_logger(os.path.join(config.path, "{}.log".format(config.name))) config.print_params(logger.info) main(config, writer, logger) writer.close() elif config.shared_policy == "group": config.shared_policy = "all" if "designated_partition" in config.shared_policy_kwargs: