Пример #1
0
    def train(self):
        """Train ``self.model`` for ``self.config.num_epoch`` epochs.

        For every batch of ``self.data_loader['train']`` all tensors are moved
        to ``self.device``; every tensor except the last is fed to the model
        and the logits are scored against the last tensor (the label) with
        ``self.criterion`` before back-propagating. Every
        ``self.config.gradient_accumulation_steps`` batches the gradient norm
        is clipped to ``self.config.max_grad_norm``, the optimizer and the
        scheduler step, gradients are zeroed, and the current loss is written
        to the ``<experiment_name>-step.csv`` logger. After each epoch the
        model is evaluated via ``_epoch_evaluate_update_description_log`` and
        saved as ``<model_path>/<experiment_name>/<model_type>-<epoch>.bin``.

        Returns:
            A deep copy of ``self.model.state_dict()`` from the epoch whose
            evaluation result at index ``-3`` was largest, or ``None`` if no
            epoch scored above 0. NOTE(review): earlier docs called this the
            best *valid* F1 while the tracking variable was named
            ``best_train_f1`` -- confirm which metric index -3 holds.
        """
        log_dir = self.config.log_path
        run_name = self.config.experiment_name
        epoch_logger = get_csv_logger(
            os.path.join(log_dir, run_name + '-epoch.csv'),
            title='epoch,train_acc,train_f1,valid_acc,valid_f1')
        step_logger = get_csv_logger(
            os.path.join(log_dir, run_name + '-step.csv'),
            title='step,loss')
        epoch_bar = trange(self.config.num_epoch, desc='Epoch', ncols=120)

        best_state = None
        best_score = 0
        optimizer_steps = 0
        for epoch_idx, _ in enumerate(epoch_bar):
            self.model.train()
            batch_bar = tqdm(self.data_loader['train'], ncols=80)
            for batch_idx, raw_batch in enumerate(batch_bar):
                tensors = tuple(t.to(self.device) for t in raw_batch)
                # The final tensor of every batch is the label.
                logits = self.model(*tensors[:-1])
                loss = self.criterion(logits, tensors[-1])
                # NOTE(review): the loss is not divided by
                # gradient_accumulation_steps, so the accumulated gradient is
                # a sum (not a mean) over the accumulated batches.
                loss.backward()

                # Gradient accumulation: several backward passes add into
                # parameter.grad before a single optimizer.step() applies the
                # accumulated gradient.
                if (batch_idx + 1) % self.config.gradient_accumulation_steps != 0:
                    continue

                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.config.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                optimizer_steps += 1
                loss_value = loss.item()
                batch_bar.set_description('loss: {:.6f}'.format(loss_value))
                step_logger.info(str(optimizer_steps) + ',' + str(loss_value))

            results = self._epoch_evaluate_update_description_log(
                tqdm_obj=epoch_bar, logger=epoch_logger, epoch=epoch_idx + 1)

            checkpoint_file = (self.config.model_type + '-' +
                               str(epoch_idx + 1) + '.bin')
            self.save_model(os.path.join(
                self.config.model_path, run_name, checkpoint_file))

            epoch_score = results[-3]
            if epoch_score > best_score:
                best_state = deepcopy(self.model.state_dict())
                best_score = epoch_score

        return best_state
Пример #2
0
    def valid(self):
        """Evaluate the current model on the train and valid splits.

        Calls ``_epoch_evaluate_update_description_log`` twice, logging both
        runs to the ``<experiment_name>-epoch.csv`` logger under epoch 0:

        * on ``self.data_loader['valid_train']`` with the ``'train_exam'`` /
          ``'train_feat'`` entries and gold file ``self.config.train_file_path``;
        * on ``self.data_loader['valid_valid']`` with the ``'valid_exam'`` /
          ``'valid_feat'`` entries and gold file ``self.config.valid_file_path``.

        Returns:
            tuple: ``(train_f1, train_sp_f1, train_joint_f1,
            valid_f1, valid_sp_f1, valid_joint_f1)`` read from the ``'f1'``,
            ``'sp_f1'`` and ``'joint_f1'`` keys of each evaluation result.
        """
        epoch_logger = get_csv_logger(
            os.path.join(self.config.log_path,
                         self.config.experiment_name + '-epoch.csv'),
            title='epoch,train_acc,train_f1,valid_acc,valid_f1')
        # Nothing is written to the step logger during validation; the call is
        # kept only so any side effect of get_csv_logger (e.g. creating the
        # step CSV) is preserved. NOTE(review): drop it if it has none.
        get_csv_logger(
            os.path.join(self.config.log_path,
                         self.config.experiment_name + '-step.csv'),
            title='step,loss')

        train_results = self._epoch_evaluate_update_description_log(
            tqdm_obj=self.data_loader['valid_train'],
            logger=epoch_logger,
            epoch=0,
            exam=self.data_loader['train_exam'],
            feats=self.data_loader['train_feat'],
            gold_file=self.config.train_file_path)

        valid_results = self._epoch_evaluate_update_description_log(
            tqdm_obj=self.data_loader['valid_valid'],
            logger=epoch_logger,
            epoch=0,
            exam=self.data_loader['valid_exam'],
            feats=self.data_loader['valid_feat'],
            gold_file=self.config.valid_file_path)

        return (train_results['f1'], train_results['sp_f1'],
                train_results['joint_f1'], valid_results['f1'],
                valid_results['sp_f1'], valid_results['joint_f1'])
Пример #3
0
    args = parser.parse_args()

    # `mem` is True exactly when --recurrence is greater than 1.
    args.mem = args.recurrence > 1

    # Set run dir

    # NOTE(review): `date` is not used within this excerpt, and `model_name`
    # is taken directly from --model with no fallback when it is None --
    # confirm against the rest of the function.
    date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")

    model_name = args.model
    model_dir = utils.get_model_dir(model_name)

    # Load loggers and Tensorboard writer (each is given model_dir)

    txt_logger = utils.get_txt_logger(model_dir)
    csv_file, csv_logger = utils.get_csv_logger(model_dir)
    tb_writer = tensorboardX.SummaryWriter(model_dir)

    # Log the full command line and the parsed argument namespace

    txt_logger.info("{}\n".format(" ".join(sys.argv)))
    txt_logger.info("{}\n".format(args))

    # Set seed for all randomness sources

    utils.seed(args.seed)

    # Set device: CUDA when available, otherwise CPU

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    txt_logger.info(f"Device: {device}\n")
Пример #4
0
if args.pretrained_gnn:
    # Locate the directory of the matching pretrained GNN run.
    # Explicit check rather than `assert`: asserts are removed under
    # `python -O`, which would silently skip this validation.
    if args.progression_mode != "full":
        raise ValueError(
            "pretrained_gnn requires progression_mode 'full', got {!r}".format(
                args.progression_mode))
    default_dir = f"symbol-storage/{args.gnn}-dumb_ac_{args.ltl_sampler}_Simple-LTL-Env-v0_seed:{args.seed}_*_prog:{args.progression_mode}/train"
    print(default_dir)
    # Exactly one run directory must match the glob pattern.
    model_dirs = glob.glob(default_dir)
    if not model_dirs:
        raise Exception("Pretraining directory not found.")
    if len(model_dirs) > 1:
        raise Exception("More than 1 candidate pretraining directory found.")

    pretrained_model_dir = model_dirs[0]
# Load loggers and Tensorboard writer
# NOTE(review): `model_dir` is defined before this excerpt; all outputs go to
# its "/train" subdirectory.

txt_logger = utils.get_txt_logger(model_dir + "/train")
csv_file, csv_logger = utils.get_csv_logger(model_dir + "/train")
tb_writer = tensorboardX.SummaryWriter(model_dir + "/train")
utils.save_config(model_dir + "/train", args)

# Log the full command line and the parsed argument namespace

txt_logger.info("{}\n".format(" ".join(sys.argv)))
txt_logger.info("{}\n".format(args))

# Set seed for all randomness sources

utils.seed(args.seed)

# Set device: CUDA when available, otherwise CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Пример #5
0
def main():
    """Parse CLI arguments, then train an actor-critic model with torch_ac.

    Creates ``--procs`` environments, restores any saved status (frame and
    update counters, model/optimizer state, vocab) from the model directory,
    and runs the selected algorithm until ``--frames`` frames have been
    collected. Every ``--log-interval`` updates progress is written to the
    text log, the CSV file and TensorBoard; every ``--save-interval`` updates
    the status is saved. With ``--visualize`` the normalized change of
    parameters 0, 2 and 4 (indexing ``slice [:, :, 0, 0]``) is plotted after
    every update.
    """

    def _parse_bool(value):
        """Convert a command-line string such as 'true' or '0' to a bool."""
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    # Parse arguments

    parser = argparse.ArgumentParser()

    ## General parameters
    parser.add_argument(
        "--algo",
        required=True,
        help=
        "algorithm to use: a2c | ppo | ppo_intrinsic | a2c_intrinsic (REQUIRED)")
    parser.add_argument("--env",
                        required=True,
                        help="name of the environment to train on (REQUIRED)")
    parser.add_argument(
        "--model",
        default=None,
        help="name of the model (default: {ENV}_{ALGO}_{TIME})")
    parser.add_argument("--seed",
                        type=int,
                        default=1,
                        help="random seed (default: 1)")
    parser.add_argument("--log-interval",
                        type=int,
                        default=1,
                        help="number of updates between two logs (default: 1)")
    parser.add_argument(
        "--save-interval",
        type=int,
        default=10,
        help=
        "number of updates between two saves (default: 10, 0 means no saving)")
    parser.add_argument("--procs",
                        type=int,
                        default=16,
                        help="number of processes (default: 16)")
    parser.add_argument("--frames",
                        type=int,
                        default=10**7,
                        help="number of frames of training (default: 1e7)")

    ## Parameters for main algorithm
    parser.add_argument("--epochs",
                        type=int,
                        default=4,
                        help="number of epochs for PPO (default: 4)")
    parser.add_argument("--batch-size",
                        type=int,
                        default=256,
                        help="batch size for PPO (default: 256)")
    parser.add_argument(
        "--frames-per-proc",
        type=int,
        default=None,
        help=
        "number of frames per process before update (default: 5 for A2C and 128 for PPO)"
    )
    parser.add_argument("--discount",
                        type=float,
                        default=0.99,
                        help="discount factor (default: 0.99)")
    parser.add_argument("--lr",
                        type=float,
                        default=0.001,
                        help="learning rate (default: 0.001)")
    parser.add_argument(
        "--gae-lambda",
        type=float,
        default=0.95,
        help="lambda coefficient in GAE formula (default: 0.95, 1 means no gae)"
    )
    parser.add_argument("--entropy-coef",
                        type=float,
                        default=0.01,
                        help="entropy term coefficient (default: 0.01)")
    parser.add_argument("--value-loss-coef",
                        type=float,
                        default=0.5,
                        help="value loss term coefficient (default: 0.5)")
    parser.add_argument("--max-grad-norm",
                        type=float,
                        default=0.5,
                        help="maximum norm of gradient (default: 0.5)")
    parser.add_argument(
        "--optim-eps",
        type=float,
        default=1e-8,
        help="Adam and RMSprop optimizer epsilon (default: 1e-8)")
    parser.add_argument("--optim-alpha",
                        type=float,
                        default=0.99,
                        help="RMSprop optimizer alpha (default: 0.99)")
    parser.add_argument("--clip-eps",
                        type=float,
                        default=0.2,
                        help="clipping epsilon for PPO (default: 0.2)")
    parser.add_argument(
        "--recurrence",
        type=int,
        default=1,
        help=
        "number of time-steps gradient is backpropagated (default: 1). If > 1, a LSTM is added to the model to have memory."
    )
    parser.add_argument("--text",
                        action="store_true",
                        default=False,
                        help="add a GRU to the model to handle text input")
    # Accepts both a bare `--visualize` and an explicit value
    # (`--visualize false`); previously any given string, including "False",
    # was truthy.
    parser.add_argument("--visualize",
                        nargs="?",
                        const=True,
                        default=False,
                        type=_parse_bool,
                        help="show real time CNN layer weight changes")

    args = parser.parse_args()

    args.mem = args.recurrence > 1

    # Set run dir

    date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    default_model_name = f"{args.env}_{args.algo}_seed{args.seed}_{date}"

    model_name = args.model or default_model_name
    model_dir = utils.get_model_dir(model_name)

    # Load loggers and Tensorboard writer

    txt_logger = utils.get_txt_logger(model_dir)
    csv_file, csv_logger = utils.get_csv_logger(model_dir)
    tb_writer = tensorboardX.SummaryWriter(model_dir)

    # Log command and all script arguments

    txt_logger.info("{}\n".format(" ".join(sys.argv)))
    txt_logger.info("{}\n".format(args))

    # Set seed for all randomness sources

    utils.seed(args.seed)

    # Set device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    txt_logger.info(f"Device: {device}\n")

    # Load environments (each process gets a distinct seed)

    envs = []
    for i in range(args.procs):
        envs.append(utils.make_env(args.env, args.seed + 10000 * i))
    txt_logger.info("Environments loaded\n")

    # Load training status; start from scratch if none can be read

    try:
        status = utils.get_status(model_dir)
    except OSError:
        status = {"num_frames": 0, "update": 0}
    txt_logger.info("Training status loaded\n")

    # Load observations preprocessor

    obs_space, preprocess_obss = utils.get_obss_preprocessor(
        envs[0].observation_space)
    if "vocab" in status:
        preprocess_obss.vocab.load_vocab(status["vocab"])
    txt_logger.info("Observations preprocessor loaded")

    # Load model

    acmodel = ACModel(obs_space, envs[0].action_space, args.mem, args.text)
    if "model_state" in status:
        acmodel.load_state_dict(status["model_state"])
    acmodel.to(device)
    txt_logger.info("Model loaded\n")
    txt_logger.info("{}\n".format(acmodel))

    # Load algo

    if args.algo == "a2c":
        algo = torch_ac.A2CAlgo(envs, acmodel, device, args.frames_per_proc,
                                args.discount, args.lr, args.gae_lambda,
                                args.entropy_coef, args.value_loss_coef,
                                args.max_grad_norm, args.recurrence,
                                args.optim_alpha, args.optim_eps,
                                preprocess_obss)
    elif args.algo == "ppo":
        algo = torch_ac.PPOAlgo(envs, acmodel, device, args.frames_per_proc,
                                args.discount, args.lr, args.gae_lambda,
                                args.entropy_coef, args.value_loss_coef,
                                args.max_grad_norm, args.recurrence,
                                args.optim_eps, args.clip_eps, args.epochs,
                                args.batch_size, preprocess_obss)

    elif args.algo == "ppo_intrinsic":
        algo = torch_ac.PPOAlgoIntrinsic(
            envs, acmodel, device, args.frames_per_proc, args.discount,
            args.lr, args.gae_lambda, args.entropy_coef, args.value_loss_coef,
            args.max_grad_norm, args.recurrence, args.optim_eps, args.clip_eps,
            args.epochs, args.batch_size, preprocess_obss)
    elif args.algo == "a2c_intrinsic":
        algo = torch_ac.A2CAlgoIntrinsic(
            envs, acmodel, device, args.frames_per_proc, args.discount,
            args.lr, args.gae_lambda, args.entropy_coef, args.value_loss_coef,
            args.max_grad_norm, args.recurrence, args.optim_alpha,
            args.optim_eps, preprocess_obss)
    else:
        raise ValueError("Incorrect algorithm name: {}".format(args.algo))

    if "optimizer_state" in status:
        algo.optimizer.load_state_dict(status["optimizer_state"])
    txt_logger.info("Optimizer loaded\n")

    # Train model

    num_frames = status["num_frames"]
    update = status["update"]
    start_time = time.time()

    # Decide once whether the CSV header is needed. Re-checking
    # status["num_frames"] on every log re-wrote the header until the first
    # status save replaced `status`.
    write_csv_header = status["num_frames"] == 0

    print_visual = args.visualize
    if print_visual:
        fig, axs = plt.subplots(1, 3)
        fig.suptitle('Convolution Layer Weights Normalized Difference')

    while num_frames < args.frames:

        if print_visual:
            # Snapshot s_t parameters. `.cpu()` is required because the model
            # may live on a CUDA device, where `.numpy()` raises.
            old_parameters = {
                name: param.detach().cpu().numpy().copy()
                for name, param in acmodel.named_parameters()
            }

        # Update model parameters
        update_start_time = time.time()
        exps, logs1 = algo.collect_experiences()
        logs2 = algo.update_parameters(exps)
        logs = {**logs1, **logs2}
        update_end_time = time.time()

        if print_visual:
            # Snapshot s_t+1 parameters and plot the normalized weight change.
            new_parameters = {
                name: param.detach().cpu().numpy().copy()
                for name, param in acmodel.named_parameters()
            }
            param_names = list(old_parameters)
            # NOTE(review): parameters 0, 2 and 4 are assumed to be 4-D conv
            # weights (the [:, :, 0, 0] slice requires it) -- confirm vs ACModel.
            for index in (0, 2, 4):
                if index >= len(param_names):
                    break
                key = param_names[index]
                diff_matrix = abs(new_parameters[key] - old_parameters[key])
                diff_matrix[:, :, 0, 0] = normalize(diff_matrix[:, :, 0, 0],
                                                    norm='max',
                                                    axis=0)
                axs[index // 2].imshow(diff_matrix[:, :, 0, 0],
                                       cmap='Greens',
                                       interpolation='nearest')

            # This allows the plots to update as the model trains
            plt.ion()
            plt.show()
            plt.pause(0.001)

        num_frames += logs["num_frames"]
        update += 1

        # Print logs

        if update % args.log_interval == 0:
            fps = logs["num_frames"] / (update_end_time - update_start_time)
            duration = int(time.time() - start_time)
            return_per_episode = utils.synthesize(logs["return_per_episode"])
            rreturn_per_episode = utils.synthesize(
                logs["reshaped_return_per_episode"])
            num_frames_per_episode = utils.synthesize(
                logs["num_frames_per_episode"])

            header = ["update", "frames", "FPS", "duration"]
            data = [update, num_frames, fps, duration]
            header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
            data += rreturn_per_episode.values()
            header += [
                "num_frames_" + key for key in num_frames_per_episode.keys()
            ]
            data += num_frames_per_episode.values()
            header += [
                "entropy", "value", "policy_loss", "value_loss", "grad_norm"
            ]
            data += [
                logs["entropy"], logs["value"], logs["policy_loss"],
                logs["value_loss"], logs["grad_norm"]
            ]

            txt_logger.info(
                "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}"
                .format(*data))

            header += ["return_" + key for key in return_per_episode.keys()]
            data += return_per_episode.values()

            if write_csv_header:
                csv_logger.writerow(header)
                write_csv_header = False
            csv_logger.writerow(data)
            csv_file.flush()

            for field, value in zip(header, data):
                tb_writer.add_scalar(field, value, num_frames)

        # Save status

        if args.save_interval > 0 and update % args.save_interval == 0:
            status = {
                "num_frames": num_frames,
                "update": update,
                "model_state": acmodel.state_dict(),
                "optimizer_state": algo.optimizer.state_dict()
            }
            if hasattr(preprocess_obss, "vocab"):
                status["vocab"] = preprocess_obss.vocab.vocab
            utils.save_status(status, model_dir)
            txt_logger.info("Status saved")
def main(raw_args=None):

    # Parse arguments
    parser = argparse.ArgumentParser()

    ## General parameters
    parser.add_argument("--algo",
                        required=True,
                        help="algorithm to use: a2c | ppo | ipo (REQUIRED)")
    parser.add_argument("--domain1",
                        required=True,
                        help="name of the first domain to train on (REQUIRED)")
    parser.add_argument(
        "--domain2",
        required=True,
        help="name of the second domain to train on (REQUIRED)")
    parser.add_argument(
        "--p1",
        required=True,
        type=float,
        help="Proportion of training environments from first domain (REQUIRED)"
    )
    parser.add_argument("--model", required=True, help="name of the model")
    parser.add_argument("--seed",
                        type=int,
                        default=1,
                        help="random seed (default: 1)")
    parser.add_argument("--log-interval",
                        type=int,
                        default=1,
                        help="number of updates between two logs (default: 1)")
    parser.add_argument(
        "--save-interval",
        type=int,
        default=10,
        help=
        "number of updates between two saves (default: 10, 0 means no saving)")
    parser.add_argument("--procs",
                        type=int,
                        default=16,
                        help="number of processes (default: 16)")
    parser.add_argument("--frames",
                        type=int,
                        default=10**7,
                        help="number of frames of training (default: 1e7)")

    ## Parameters for main algorithm
    parser.add_argument("--epochs",
                        type=int,
                        default=4,
                        help="number of epochs for PPO (default: 4)")
    parser.add_argument("--batch-size",
                        type=int,
                        default=256,
                        help="batch size for PPO (default: 256)")
    parser.add_argument(
        "--frames-per-proc",
        type=int,
        default=None,
        help=
        "number of frames per process before update (default: 5 for A2C and 128 for PPO)"
    )
    parser.add_argument("--discount",
                        type=float,
                        default=0.99,
                        help="discount factor (default: 0.99)")
    parser.add_argument("--lr",
                        type=float,
                        default=0.001,
                        help="learning rate (default: 0.001)")
    parser.add_argument(
        "--gae-lambda",
        type=float,
        default=0.95,
        help="lambda coefficient in GAE formula (default: 0.95, 1 means no gae)"
    )
    parser.add_argument("--entropy-coef",
                        type=float,
                        default=0.01,
                        help="entropy term coefficient (default: 0.01)")
    parser.add_argument("--value-loss-coef",
                        type=float,
                        default=0.5,
                        help="value loss term coefficient (default: 0.5)")
    parser.add_argument("--max-grad-norm",
                        type=float,
                        default=0.5,
                        help="maximum norm of gradient (default: 0.5)")
    parser.add_argument(
        "--optim-eps",
        type=float,
        default=1e-8,
        help="Adam and RMSprop optimizer epsilon (default: 1e-8)")
    parser.add_argument("--optim-alpha",
                        type=float,
                        default=0.99,
                        help="RMSprop optimizer alpha (default: 0.99)")
    parser.add_argument("--clip-eps",
                        type=float,
                        default=0.2,
                        help="clipping epsilon for PPO (default: 0.2)")
    parser.add_argument(
        "--recurrence",
        type=int,
        default=1,
        help=
        "number of time-steps gradient is backpropagated (default: 1). If > 1, a LSTM is added to the model to have memory."
    )
    parser.add_argument("--text",
                        action="store_true",
                        default=False,
                        help="add a GRU to the model to handle text input")

    args = parser.parse_args(raw_args)

    args.mem = args.recurrence > 1

    # Check PyTorch version
    if (torch.__version__ != '1.2.0'):
        raise ValueError(
            "PyTorch version must be 1.2.0 (see README). Your version is {}.".
            format(torch.__version__))

    if args.mem:
        raise ValueError("Policies with memory not supported.")

    # Set run dir

    date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    default_model_name = args.model

    model_name = args.model or default_model_name
    model_dir = utils.get_model_dir(model_name)

    # Load loggers and Tensorboard writer

    txt_logger = utils.get_txt_logger(model_dir)
    csv_file, csv_logger = utils.get_csv_logger(model_dir)
    tb_writer = tensorboardX.SummaryWriter(model_dir)

    # Log command and all script arguments

    txt_logger.info("{}\n".format(" ".join(sys.argv)))
    txt_logger.info("{}\n".format(args))

    # Set seed for all randomness sources

    torch.backends.cudnn.deterministic = True
    utils.seed(args.seed)

    # Set device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    txt_logger.info(f"Device: {device}\n")

    # Load environments from different domains
    domain1 = args.domain1  # e.g., 'MiniGrid-ColoredKeysRed-v0'
    domain2 = args.domain2  # e.g., 'MiniGrid-ColoredKeysYellow-v0'

    p1 = args.p1  # Proportion of environments from domain1

    num_envs_total = args.procs  # Total number of environments
    num_domain1 = math.ceil(
        p1 * num_envs_total)  # Number of environments in domain1
    num_domain2 = num_envs_total - num_domain1  # Number of environments in domain2

    # Environments from domain1
    envs1 = []
    for i in range(num_domain1):
        envs1.append(utils.make_env(domain1, args.seed + 10000 * i))

    # Environments from domain2
    envs2 = []
    for i in range(num_domain2):
        envs2.append(utils.make_env(domain2, args.seed + 10000 * i))

    # All environments
    envs = envs1 + envs2

    txt_logger.info("Environments loaded\n")

    # Load training status

    try:
        status = utils.get_status(model_dir)
    except OSError:
        status = {"num_frames": 0, "update": 0}
    txt_logger.info("Training status loaded\n")

    # Load observations preprocessor

    obs_space, preprocess_obss = utils.get_obss_preprocessor(
        envs[0].observation_space)
    if "vocab" in status:
        preprocess_obss.vocab.load_vocab(status["vocab"])
    txt_logger.info("Observations preprocessor loaded")

    if args.algo == "ipo":
        # Load model for IPO game
        acmodel = ACModel_average(obs_space, envs[0].action_space, args.mem,
                                  args.text)
        if "model_state" in status:
            acmodel.load_state_dict(status["model_state"])
        acmodel.to(device)
        txt_logger.info("Model loaded\n")
        txt_logger.info("{}\n".format(acmodel))

    else:
        # Load model (for standard PPO or A2C)
        acmodel = ACModel(obs_space, envs[0].action_space, args.mem, args.text)
        if "model_state" in status:
            acmodel.load_state_dict(status["model_state"])
        acmodel.to(device)
        txt_logger.info("Model loaded\n")
        txt_logger.info("{}\n".format(acmodel))

    # Load algo

    if args.algo == "a2c":
        algo = torch_ac.A2CAlgo(envs, acmodel, device, args.frames_per_proc,
                                args.discount, args.lr, args.gae_lambda,
                                args.entropy_coef, args.value_loss_coef,
                                args.max_grad_norm, args.recurrence,
                                args.optim_alpha, args.optim_eps,
                                preprocess_obss)
        if "optimizer_state" in status:
            algo.optimizer.load_state_dict(status["optimizer_state"])
            txt_logger.info("Optimizer loaded\n")

    elif args.algo == "ppo":
        algo = torch_ac.PPOAlgo(envs, acmodel, device, args.frames_per_proc,
                                args.discount, args.lr, args.gae_lambda,
                                args.entropy_coef, args.value_loss_coef,
                                args.max_grad_norm, args.recurrence,
                                args.optim_eps, args.clip_eps, args.epochs,
                                args.batch_size, preprocess_obss)

        if "optimizer_state" in status:
            algo.optimizer.load_state_dict(status["optimizer_state"])
            txt_logger.info("Optimizer loaded\n")

    elif args.algo == "ipo":
        # One algo per domain. These have different envivonments, but shared acmodel
        algo1 = torch_ac.IPOAlgo(
            envs1, acmodel, 1, device, args.frames_per_proc, args.discount,
            args.lr, args.gae_lambda, args.entropy_coef, args.value_loss_coef,
            args.max_grad_norm, args.recurrence, args.optim_eps, args.clip_eps,
            args.epochs, args.batch_size, preprocess_obss)

        algo2 = torch_ac.IPOAlgo(
            envs2, acmodel, 2, device, args.frames_per_proc, args.discount,
            args.lr, args.gae_lambda, args.entropy_coef, args.value_loss_coef,
            args.max_grad_norm, args.recurrence, args.optim_eps, args.clip_eps,
            args.epochs, args.batch_size, preprocess_obss)

        if "optimizer_state1" in status:
            algo1.optimizer.load_state_dict(status["optimizer_state1"])
            txt_logger.info("Optimizer 1 loaded\n")
        if "optimizer_state2" in status:
            algo2.optimizer.load_state_dict(status["optimizer_state2"])
            txt_logger.info("Optimizer 2 loaded\n")

    else:
        raise ValueError("Incorrect algorithm name: {}".format(args.algo))

    # Train model

    num_frames = status["num_frames"]
    update = status["update"]
    start_time = time.time()

    while num_frames < args.frames:
        # Update model parameters

        update_start_time = time.time()

        if args.algo == "ipo":

            # Standard method

            # Collect experiences on first domain
            exps1, logs_exps1 = algo1.collect_experiences()

            # Update params of model corresponding to first domain
            logs_algo1 = algo1.update_parameters(exps1)

            # Collect experiences on second domain
            exps2, logs_exps2 = algo2.collect_experiences()

            # Update params of model corresponding to second domain
            logs_algo2 = algo2.update_parameters(exps2)

            # Update end time
            update_end_time = time.time()

            # Combine logs
            logs_exps = {
                'return_per_episode':
                logs_exps1["return_per_episode"] +
                logs_exps2["return_per_episode"],
                'reshaped_return_per_episode':
                logs_exps1["reshaped_return_per_episode"] +
                logs_exps2["reshaped_return_per_episode"],
                'num_frames_per_episode':
                logs_exps1["num_frames_per_episode"] +
                logs_exps2["num_frames_per_episode"],
                'num_frames':
                logs_exps1["num_frames"] + logs_exps2["num_frames"]
            }

            logs_algo = {
                'entropy':
                (num_domain1 * logs_algo1["entropy"] +
                 num_domain2 * logs_algo2["entropy"]) / num_envs_total,
                'value': (num_domain1 * logs_algo1["value"] +
                          num_domain2 * logs_algo2["value"]) / num_envs_total,
                'policy_loss':
                (num_domain1 * logs_algo1["policy_loss"] +
                 num_domain2 * logs_algo2["policy_loss"]) / num_envs_total,
                'value_loss':
                (num_domain1 * logs_algo1["value_loss"] +
                 num_domain2 * logs_algo2["value_loss"]) / num_envs_total,
                'grad_norm':
                (num_domain1 * logs_algo1["grad_norm"] +
                 num_domain2 * logs_algo2["grad_norm"]) / num_envs_total
            }

            logs = {**logs_exps, **logs_algo}
            num_frames += logs["num_frames"]

        else:
            exps, logs1 = algo.collect_experiences()
            logs2 = algo.update_parameters(exps)
            logs = {**logs1, **logs2}
            update_end_time = time.time()
            num_frames += logs["num_frames"]

        update += 1

        # Print logs

        if update % args.log_interval == 0:
            fps = logs["num_frames"] / (update_end_time - update_start_time)
            duration = int(time.time() - start_time)
            return_per_episode = utils.synthesize(logs["return_per_episode"])
            rreturn_per_episode = utils.synthesize(
                logs["reshaped_return_per_episode"])
            num_frames_per_episode = utils.synthesize(
                logs["num_frames_per_episode"])

            header = ["update", "frames", "FPS", "duration"]
            data = [update, num_frames, fps, duration]
            header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
            data += rreturn_per_episode.values()
            header += [
                "num_frames_" + key for key in num_frames_per_episode.keys()
            ]
            data += num_frames_per_episode.values()
            header += [
                "entropy", "value", "policy_loss", "value_loss", "grad_norm"
            ]
            data += [
                logs["entropy"], logs["value"], logs["policy_loss"],
                logs["value_loss"], logs["grad_norm"]
            ]

            txt_logger.info(
                "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}"
                .format(*data))

            header += ["return_" + key for key in return_per_episode.keys()]
            data += return_per_episode.values()

            # header += ["debug_last_env_reward"]
            # data += [logs["debug_last_env_reward"]]

            header += ["total_loss"]
            data += [
                logs["policy_loss"] - args.entropy_coef * logs["entropy"] +
                args.value_loss_coef * logs["value_loss"]
            ]

            if status["num_frames"] == 0:
                csv_logger.writerow(header)

            csv_logger.writerow(data)
            csv_file.flush()

            for field, value in zip(header, data):
                tb_writer.add_scalar(field, value, num_frames)

        # Save status

        if args.save_interval > 0 and update % args.save_interval == 0:

            if args.algo == "ipo":
                status = {
                    "num_frames": num_frames,
                    "update": update,
                    "model_state": acmodel.state_dict(),
                    "optimizer_state1": algo1.optimizer.state_dict(),
                    "optimizer_state2": algo2.optimizer.state_dict()
                }
            else:
                status = {
                    "num_frames": num_frames,
                    "update": update,
                    "model_state": acmodel.state_dict(),
                    "optimizer_state": algo.optimizer.state_dict()
                }

            if hasattr(preprocess_obss, "vocab"):
                status["vocab"] = preprocess_obss.vocab.vocab
            utils.save_status(status, model_dir)
            txt_logger.info("Status saved")
Пример #7
0
    def train(self):
        """Fit the model epoch by epoch and return the best weights seen.

        Each epoch makes one pass over ``self.data_loader['train']``:
        - The span (start/end), question-type and supporting-sentence losses
          are summed and back-propagated.
        - Every ``config.gradient_accumulation_steps`` batches, the gradients
          are clipped to ``config.max_grad_norm``, then the optimizer and the
          scheduler are stepped and the gradients are cleared.

        After the pass, the model is evaluated on the ``valid_train`` and
        ``valid_valid`` loaders, which also writes to the epoch CSV log. A
        checkpoint ``<model_type>-<epoch>.bin`` is then saved under
        ``config.model_path/config.experiment_name``.

        Returns:
            A deep copy of the model state dict from the epoch with the
            highest *train* ``joint_f1``, or ``None`` if that score never
            exceeded 0.
        """
        cfg = self.config
        epoch_logger = get_csv_logger(
            os.path.join(cfg.log_path, cfg.experiment_name + '-epoch.csv'),
            title='epoch,train_acc,train_f1,valid_acc,valid_f1')
        step_logger = get_csv_logger(
            os.path.join(cfg.log_path, cfg.experiment_name + '-step.csv'),
            title='step,loss')
        epoch_bar = trange(cfg.num_epoch, desc='Epoch', ncols=120)

        def compute_losses(inputs):
            # The model yields six outputs; the trailing start/end positions
            # are not used by the loss.
            start_logits, end_logits, type_logits, sp_logits, _, _ = self.model(
                *inputs)
            span_loss = (self.criterion(start_logits, inputs[6])
                         + self.criterion(end_logits, inputs[7]))
            type_loss = cfg.type_lambda * self.criterion(type_logits, inputs[8])
            # Summed per-sentence loss divided by inputs[9].sum()
            # (presumably the count of real sentences -- TODO confirm).
            sp_loss = cfg.sp_lambda * self.sp_loss_fct(
                sp_logits.view(-1),
                inputs[10].float().view(-1)).sum() / inputs[9].sum()
            return span_loss, type_loss, sp_loss

        best_state, best_score, updates = None, 0, 0
        for epoch_idx, _ in enumerate(epoch_bar):
            self.model.train()
            batch_bar = tqdm(self.data_loader['train'], ncols=80)
            for batch_idx, raw_batch in enumerate(batch_bar):
                inputs = tuple(tensor.to(self.device) for tensor in raw_batch)
                span_loss, type_loss, sp_loss = compute_losses(inputs)
                total = span_loss + type_loss + sp_loss
                # Gradients keep accumulating in .grad until the next
                # optimizer step below.
                total.backward()

                if (batch_idx + 1) % cfg.gradient_accumulation_steps != 0:
                    continue
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               cfg.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                updates += 1
                batch_bar.set_description(
                    f'loss: {span_loss.item():.6f} {type_loss.item():.6f} '
                    f'{sp_loss.item():.6f}')
                step_logger.info(f'{updates},{total.item()}')

            epoch_no = epoch_idx + 1
            train_scores = self._epoch_evaluate_update_description_log(
                tqdm_obj=self.data_loader['valid_train'],
                logger=epoch_logger,
                epoch=epoch_no,
                exam=self.data_loader['train_exam'],
                feats=self.data_loader['train_feat'])
            valid_scores = self._epoch_evaluate_update_description_log(
                tqdm_obj=self.data_loader['valid_valid'],
                logger=epoch_logger,
                epoch=epoch_no,
                exam=self.data_loader['valid_exam'],
                feats=self.data_loader['valid_feat'])

            metric_keys = ('f1', 'sp_f1', 'joint_f1')
            scores = ([train_scores[key] for key in metric_keys]
                      + [valid_scores[key] for key in metric_keys])
            self.save_model(
                os.path.join(cfg.model_path, cfg.experiment_name,
                             cfg.model_type + '-' + str(epoch_no) + '.bin'))

            # scores[-4] is the *train* joint_f1.
            if scores[-4] > best_score:
                best_state = deepcopy(self.model.state_dict())
                best_score = scores[-4]

        return best_state
Пример #8
0
    if args.fp16:
        import apex
        apex.amp.register_half_function(torch, "einsum")
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    model = torch.nn.DataParallel(model)
    model.train()

    # Training
    global_step = epc = 0
    total_train_loss = [0] * 5
    test_loss_record = []
    VERBOSE_STEP = args.verbose_step

    epoch_logger = get_csv_logger(os.path.join("log/", args.name + '-epoch.csv'),
        title='epoch,em,f1,prec,recall,sp_em,sp_f1,sp_prec,sp_recall,joint_em,joint_f1,joint_prec,joint_recall')
    while True:
        if epc == args.epochs:  # 5 + 30
            exit(0)
        epc += 1

        Loader = Full_Loader
        Loader.refresh()

        train_epoch(Loader, model, logger=epoch_logger, predict_during_train=False, epoch=epc)
        # if epc > 2:
        #
        # else:
        #     train_epoch(Loader, model, logger=None)
Пример #9
0
    def __init__(self,
                 env,
                 model_dir,
                 model_type='PPO2',
                 logger=None,
                 argmax=False,
                 use_memory=False,
                 use_text=False,
                 num_cpu=1,
                 frames_per_proc=None,
                 discount=0.99,
                 lr=0.001,
                 gae_lambda=0.95,
                 entropy_coef=0.01,
                 value_loss_coef=0.5,
                 max_grad_norm=0.5,
                 recurrence=1,
                 optim_eps=1e-8,
                 optim_alpha=None,
                 clip_eps=0.2,
                 epochs=4,
                 batch_size=256):
        """
        Initialize the Agent object.

        This mostly stores the configuration parameters. It also:
        - resolves the save directory;
        - creates the CSV and TensorBoard loggers;
        - restores any saved status (model weights, vocab);
        - builds the actor-critic model.

        :param env: the environment for training
        :param model_dir: the save directory (appended with the goal_id when the env has a goal)
        :param model_type: the type of model {'PPO2', 'A2C'}
        :param logger: existing text logger; may be None, in which case text logging is skipped
        :param argmax: if we use deterministic or probabilistic action selection
        :param use_memory: if we are using an LSTM
        :param use_text: if we are using NLP to parse the goal
        :param num_cpu: the number of parallel instances for training
        :param frames_per_proc: max time_steps per process (versus constant)
        :param discount: the discount factor (gamma)
        :param lr: the learning rate
        :param gae_lambda: the generalized advantage estimator lambda parameter (training smoothing parameter)
        :param entropy_coef: relative weight for entropy loss
        :param value_loss_coef: relative weight for value function loss
        :param max_grad_norm: max scaling factor for the gradient
        :param recurrence: number of recurrent steps
        :param optim_eps: minimum value to prevent numerical instability
        :param optim_alpha: RMSprop decay parameter (A2C only)
        :param clip_eps: clipping parameter for the advantage and value function (PPO2 only)
        :param epochs: number of epochs in the parameter update (PPO2 only)
        :param batch_size: number of samples for the parameter update (PPO2 only)
        """
        # If the environment has a goal, save under a per-goal sub-folder.
        if hasattr(env, 'goal') and env.goal:
            self.model_dir = model_dir + env.goal.goalId + '/'
        else:
            self.model_dir = model_dir

        # Store all of the input parameters.
        self.model_type = model_type
        self.num_cpu = num_cpu
        self.frames_per_proc = frames_per_proc
        self.discount = discount
        self.lr = lr
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.value_loss_coef = value_loss_coef
        self.max_grad_norm = max_grad_norm
        self.recurrence = recurrence
        self.optim_eps = optim_eps
        self.optim_alpha = optim_alpha
        self.clip_eps = clip_eps
        self.epochs = epochs
        self.batch_size = batch_size

        # Use the existing text logger (optional) and create two new ones.
        self.txt_logger = logger
        self.csv_file, self.csv_logger = utils.get_csv_logger(self.model_dir)
        self.tb_writer = tensorboardX.SummaryWriter(self.model_dir)

        def log(message):
            # `logger` defaults to None, so every text-log call is guarded.
            if self.txt_logger:
                self.txt_logger.info(message)

        # Set the environment (with additional checks and init of training envs).
        self.set_env(env)

        self.algo = None  # the algorithm is created later by init_training_algo()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log(f"Device: {device}\n")

        try:  # if we have a saved model, load it
            self.status = utils.get_status(self.model_dir)
        except OSError:  # otherwise initialize the status
            print('error loading saved model.  initializing empty model...')
            self.status = {"num_frames": 0, "update": 0}
        log("Training status loaded\n")

        # Get the obs_space and the observation pre-processor
        # (turns gym observations into a torch-friendly format).
        obs_space, self.preprocess_obss = utils.get_obss_preprocessor(
            self.env.observation_space)
        # A saved vocab can only be restored once the preprocessor exists.
        # Only text preprocessors carry a vocab.
        if "vocab" in self.status and hasattr(self.preprocess_obss, "vocab"):
            self.preprocess_obss.vocab.load_vocab(self.status["vocab"])
        log("Observations preprocessor loaded")

        self.acmodel = ACModel(obs_space,
                               self.env.action_space,
                               use_memory=use_memory,
                               use_text=use_text)
        self.device = device  # the torch device {'cpu', 'cuda'}
        self.argmax = argmax  # greedy (True) vs probabilistic (False) action selection

        if self.acmodel.recurrent:  # initialize the memories
            self.memories = torch.zeros(num_cpu,
                                        self.acmodel.memory_size,
                                        device=self.device)

        # Restore saved weights, if any, then move the model to the device.
        if "model_state" in self.status:
            self.acmodel.load_state_dict(self.status["model_state"])
        self.acmodel.to(device)
        log("Model loaded\n")
        log("{}\n".format(self.acmodel))

        self.acmodel.eval()
        if hasattr(self.preprocess_obss, "vocab"):
            # NOTE(review): this reads the vocab from `model_dir`, not the
            # goal-specific `self.model_dir` used for status/logs -- confirm
            # which one is intended.
            self.preprocess_obss.vocab.load_vocab(utils.get_vocab(model_dir))
Пример #10
0
def train(parameters, config, gpu_list, do_test=False):
    """Train ``parameters["model"]`` and periodically validate it.

    Resumes at epoch ``parameters["trained_epoch"] + 1`` and runs up to the
    ``[train] epoch`` config value. Each epoch:
    - does a full pass over ``parameters["train_dataset"]`` with gradient
      accumulation and gradient-norm clipping;
    - prints progress every ``[output] output_time`` steps;
    - writes a checkpoint and TensorBoard scalars;
    - validates every ``[output] test_time`` epochs, optionally also on the
      test set;
    - appends one row to ``log/bert-epoch.csv``.

    Args:
        parameters: dict with ``model``, ``optimizer``, ``train_dataset``,
            ``valid_dataset``, ``trained_epoch``, ``global_step`` and
            ``output_function`` entries.
        config: ConfigParser-like object with ``train`` and ``output``
            sections (read via ``get``/``getint``/``getfloat``).
        gpu_list: device list; when non-empty, tensors are moved to CUDA.
        do_test: also evaluate on the test dataset whenever validating.

    Raises:
        NotImplementedError: if an epoch yields no batches.
    """
    get_path("log")
    epoch_logger = get_csv_logger(
        os.path.join("log", 'bert-epoch.csv'),
        title='epoch,train_acc,train_f1,valid_acc,valid_f1')

    epoch = config.getint("train", "epoch")

    output_time = config.getint("output", "output_time")
    test_time = config.getint("output", "test_time")

    model_name = config.get("output", "model_name")
    output_path = os.path.join(config.get("output", "model_path"), model_name)
    if os.path.exists(output_path):
        logger.warning(
            "Output path exists, check whether need to change a name of model")
    os.makedirs(output_path, exist_ok=True)

    trained_epoch = parameters["trained_epoch"] + 1
    model = parameters["model"]
    optimizer = parameters["optimizer"]
    dataset = parameters["train_dataset"]
    global_step = parameters["global_step"]
    output_function = parameters["output_function"]

    if do_test:
        init_formatter(config, ["test"])
        test_dataset = init_test_dataset(config)

    tensorboard_dir = os.path.join(config.get("output", "tensorboard_path"),
                                   model_name)
    if trained_epoch == 0:
        # Fresh run: discard old TensorBoard logs (errors ignored).
        shutil.rmtree(tensorboard_dir, True)
    os.makedirs(tensorboard_dir, exist_ok=True)

    writer = SummaryWriter(tensorboard_dir, model_name)

    step_size = config.getint("train", "step_size")
    gamma = config.getfloat("train", "lr_multiplier")

    gradient_accumulation_steps = config.getint("train",
                                                "gradient_accumulation_steps")
    max_grad_norm = config.getfloat("train", "max_grad_norm")

    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=step_size,
                                           gamma=gamma)
    exp_lr_scheduler.step(trained_epoch)

    logger.info("Training start....")

    print("Epoch  Stage  Iterations  Time Usage    Loss    Output Information")
    total_len = len(dataset)
    print('total len', total_len)
    for epoch_num in range(trained_epoch, epoch):
        start_time = timer()
        current_epoch = epoch_num

        exp_lr_scheduler.step(current_epoch)

        acc_result = None
        total_loss = 0

        output_info = ""
        step = -1

        tqdm_obj = tqdm(dataset, ncols=80)
        for step, data in enumerate(tqdm_obj):
            for key in data.keys():
                if isinstance(data[key], torch.Tensor):
                    if len(gpu_list) > 0:
                        data[key] = Variable(data[key].cuda())
                    else:
                        data[key] = Variable(data[key])

            results = model(data, config, gpu_list, acc_result, "train")
            loss, acc_result = results["loss"], results["acc_result"]
            if gradient_accumulation_steps > 1:
                # Scale so the accumulated gradient averages the micro-batches.
                loss = loss / gradient_accumulation_steps
            total_loss += float(loss)
            loss.backward()

            if (step + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

            if step % output_time == 0:
                output_info = output_function(acc_result, config)

                delta_t = timer() - start_time

                output_value(
                    current_epoch, "train", "%d/%d" % (step + 1, total_len),
                    "%s/%s" % (gen_time_str(delta_t),
                               gen_time_str(delta_t * (total_len - step - 1) /
                                            (step + 1))),
                    "%.3lf" % (total_loss / (step + 1)), output_info, '\r',
                    config)

            writer.add_scalar(model_name + "_train_iter", float(loss),
                              global_step)

        # Check for an empty epoch *before* summarising it. With no batches,
        # acc_result is still None and delta_t was never assigned.
        if step == -1:
            logger.error(
                "There is no data given to the model in this epoch, check your data."
            )
            raise NotImplementedError

        train_metrics = [str(value) for value in p_f1(acc_result)]
        # Use the elapsed time at the end of the epoch, not at the last
        # progress print.
        delta_t = timer() - start_time
        output_value(
            current_epoch, "train", "%d/%d" % (step + 1, total_len), "%s/%s" %
            (gen_time_str(delta_t),
             gen_time_str(delta_t * (total_len - step - 1) / (step + 1))),
            "%.3lf" % (total_loss / (step + 1)), output_info, None, config)

        checkpoint(os.path.join(output_path, "model%d.bin" % current_epoch),
                   model, optimizer, current_epoch, config, global_step)
        writer.add_scalar(model_name + "_train_epoch",
                          float(total_loss) / (step + 1), current_epoch)

        # Validation only runs every `test_time` epochs. On other epochs the
        # valid columns stay empty instead of failing (before the first
        # validation) or repeating a stale earlier result.
        validp_f1 = None
        if current_epoch % test_time == 0:
            with torch.no_grad():
                validp_f1 = valid(model, parameters["valid_dataset"],
                                  current_epoch, writer, config, gpu_list,
                                  output_function)
                if do_test:
                    valid(model,
                          test_dataset,
                          current_epoch,
                          writer,
                          config,
                          gpu_list,
                          output_function,
                          mode="test")

        if validp_f1 is None:
            # Assumes the valid metrics have the same arity as the train
            # metrics (the CSV header lists two of each) -- TODO confirm.
            valid_metrics = [""] * len(train_metrics)
        else:
            valid_metrics = [str(value) for value in validp_f1]

        # Logging
        epoch_logger.info(','.join([str(epoch_num)] + train_metrics +
                                   valid_metrics))
Пример #11
0
def main(args):
    """Train ``WRAPPED_MODEL`` on an XLA device and evaluate after every epoch.

    Setup:
    - Reads and featurises the train and dev data, caching examples and
      features as gzipped pickles under ``data_model/``.
    - Builds the loaders through ``DataHelper``.

    Then, for each of ``args.epochs`` epochs:
    - trains on the train loader;
    - predicts on the dev loader, scores the predictions, and appends one
      metrics row to ``log/<args.name>-epoch.csv``.

    The process ends with ``exit(0)``.

    Args:
        args: parsed command-line namespace (e.g. ``bert_model``,
            ``rawdata``, ``validdata``, ``max_seq_len``, ``max_query_len``,
            ``lr``, ``epochs``, ``gradient_accumulation_steps``,
            ``warmup_step``, ``fp16``, ``verbose_step``, ``name``, ``seed``,
            ``prediction_path``).
    """
    global total_train_loss

    def load_dataset():
        # Featurise train/dev data, cache both as gzipped pickles, and build
        # the DataHelper from those caches.
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)

        examples = read_examples(full_file=args.rawdata)
        with gzip.open("data_model/train_example.pkl.gz", 'wb') as fout:
            pickle.dump(examples, fout)
        features = convert_examples_to_features(
            examples, tokenizer, max_seq_length=args.max_seq_len,
            max_query_length=args.max_query_len)
        with gzip.open("data_model/train_feature.pkl.gz", 'wb') as fout:
            pickle.dump(features, fout)

        examples = read_examples(full_file=args.validdata)
        with gzip.open("data_model/dev_example.pkl.gz", 'wb') as fout:
            pickle.dump(examples, fout)
        features = convert_examples_to_features(
            examples, tokenizer, max_seq_length=args.max_seq_len,
            max_query_length=args.max_query_len)
        with gzip.open("data_model/dev_feature.pkl.gz", 'wb') as fout:
            pickle.dump(features, fout)

        return DataHelper(gz=True, config=args)

    # SERIAL_EXEC is presumably a torch_xla MpSerialExecutor, so the file
    # writes above happen one process at a time -- confirm.
    helper = SERIAL_EXEC.run(load_dataset)

    args.n_type = helper.n_type

    # Set datasets
    full_loader = helper.train_loader
    dev_example_dict = helper.dev_example_dict
    dev_feature_dict = helper.dev_feature_dict
    eval_dataset = helper.dev_loader

    # TPU
    device = xm.xla_device()
    model = WRAPPED_MODEL
    model.to(device)

    # Initialize optimizer and criterions
    t_total = len(full_loader) * args.epochs // args.gradient_accumulation_steps
    optimizer = AdamW(model.parameters(), lr=args.lr, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=args.warmup_step,
                                                num_training_steps=t_total)
    criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=IGNORE_INDEX)  # cross-entropy loss
    # Loss for the supporting-fact (sp) head. It is unreduced so the average
    # can be computed manually.
    sp_loss_fct = nn.BCEWithLogitsLoss(reduction='none')

    if args.fp16:
        import apex
        apex.amp.register_half_function(torch, "einsum")
        # Go through the `apex` module imported here; a bare `amp` name is
        # not bound in this function.
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O1')

    model.train()

    # Training
    global_step = 0
    test_loss_record = []
    VERBOSE_STEP = args.verbose_step
    tracker = xm.RateTracker()

    epoch_logger = get_csv_logger(os.path.join("log/", args.name + '-epoch.csv'),
        title='epoch,em,f1,prec,recall,sp_em,sp_f1,sp_prec,sp_recall,joint_em,joint_f1,joint_prec,joint_recall')

    def train_fn(data_loader, dev_example_dict, dev_feature_dict, model, optimizer, scheduler,
                 criterion, sp_loss_fct, logger, predict_during_train=False, epoch=1, global_step=0,
                 test_loss_record=None):
        """Run one training epoch over ``data_loader``.

        Returns:
            The global step after this epoch, so the caller can carry it over.
        """
        model.train()
        for x, batch in enumerate(data_loader):
            batch['context_mask'] = batch['context_mask'].float()
            train_batch(model, optimizer, scheduler, criterion, sp_loss_fct, batch, global_step)
            global_step += 1
            # Without add() the RateTracker always reports a rate of 0.
            tracker.add(batch['context_mask'].size(0))
            if xm.get_ordinal() == 0 and x % VERBOSE_STEP == 0:
                print('[xla:{}]({}) Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                    xm.get_ordinal(), x, tracker.rate(),
                    tracker.global_rate(), time.asctime()), flush=True)
        return global_step

    def test_fn(data_loader, dev_example_dict, dev_feature_dict, model, optimizer, scheduler,
                criterion, sp_loss_fct, logger, predict_during_train=False, epoch=1, global_step=0,
                test_loss_record=None):
        """Predict on ``data_loader``, score the predictions, and log one CSV row."""
        # Predict in evaluation mode; train_fn switches back to training mode
        # at the start of every epoch.
        model.eval()
        pred_file = join(args.prediction_path,
                         'pred_seed_{}_epoch_{}_99999.json'.format(args.seed, epoch))

        predict(model, data_loader, dev_example_dict, dev_feature_dict,
                pred_file, test_loss_record)

        results = eval(pred_file, args.validdata)
        # Logging
        keys = 'em,f1,prec,recall,sp_em,sp_f1,sp_prec,sp_recall,joint_em,joint_f1,joint_prec,joint_recall'.split(
            ',')
        logger.info(','.join([str(epoch)] + [str(results[s]) for s in keys]))

    for epc in range(1, args.epochs + 1):
        full_loader.refresh()

        # ParallelLoader is neither iterable nor sized; iterate the loader
        # for this device instead.
        train_device_loader = pl.ParallelLoader(full_loader, [device]).per_device_loader(device)
        global_step = train_fn(train_device_loader, dev_example_dict,
                               dev_feature_dict, model, optimizer, scheduler, criterion, sp_loss_fct,
                               logger=epoch_logger, predict_during_train=False, epoch=epc,
                               global_step=global_step, test_loss_record=test_loss_record)
        xm.master_print("Finished training epoch {}".format(epc))

        eval_device_loader = pl.ParallelLoader(eval_dataset, [device]).per_device_loader(device)
        test_fn(eval_device_loader, dev_example_dict,
                dev_feature_dict, model, optimizer, scheduler, criterion, sp_loss_fct,
                logger=epoch_logger, predict_during_train=False, epoch=epc,
                global_step=global_step, test_loss_record=test_loss_record)
        xm.master_print("Finished evaluating epoch {}".format(epc))

    exit(0)
def tuner(icm_lr, reward_weighting, normalise_rewards, args):
    """Train an actor-critic agent with an autoencoder-based exploration model.

    Creates 16 environments, restores any saved training status from the model
    directory, builds the A2C (or PPO) algorithm around an
    ``AutoencoderWithUncertainty`` model, and alternates experience collection
    and parameter updates until ``args.frames`` frames have been processed.
    Every update's metrics go to ``log_to_wandb``; a text summary is written
    every ``args.log_interval`` updates and the training status is saved every
    ``args.save_interval`` updates (either interval <= 0 disables that step).

    Args:
        icm_lr: learning rate of the Adam optimiser for the autoencoder.
        reward_weighting: forwarded unchanged to the algorithm constructor.
        normalise_rewards: forwarded unchanged to the algorithm constructor.
        args: parsed command-line namespace; ``args.mem`` is set here from
            ``args.recurrence``.

    Raises:
        ValueError: if ``args.algo`` is neither ``"a2c"`` nor ``"ppo"``.
    """
    import datetime
    import sys
    import time

    import tensorboardX
    import torch

    from model import ACModel
    from .a2c import A2CAlgo
    from .models import AutoencoderWithUncertainty

    # A recurrent (memory) model is only needed when backprop spans >1 step.
    args.mem = args.recurrence > 1

    # Set run dir
    date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    default_model_name = f"{args.env}_{args.algo}_seed{args.seed}_{date}"
    model_name = args.model or default_model_name
    model_dir = utils.get_model_dir(model_name)

    # Load loggers and Tensorboard writer
    txt_logger = utils.get_txt_logger(model_dir)
    csv_file, _ = utils.get_csv_logger(model_dir)
    tb_writer = tensorboardX.SummaryWriter(model_dir)

    try:
        # Log command and all script arguments
        txt_logger.info("{}\n".format(" ".join(sys.argv)))
        txt_logger.info("{}\n".format(args))

        # Set seed for all randomness sources
        utils.seed(args.seed)

        # Training is pinned to the CPU.
        device = "cpu"  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
        txt_logger.info(f"Device: {device}\n")

        # Load environments
        envs = [
            utils.make_env(
                args.env, int(args.frames_before_reset), int(args.environment_seed)
            )
            for _ in range(16)
        ]
        txt_logger.info("Environments loaded\n")

        # Load training status; start fresh if none was saved yet.
        try:
            status = utils.get_status(model_dir)
        except OSError:
            status = {"num_frames": 0, "update": 0}
        txt_logger.info("Training status loaded\n")

        # Load observations preprocessor
        obs_space, preprocess_obss = utils.get_obss_preprocessor(
            envs[0].observation_space
        )
        if "vocab" in status:
            preprocess_obss.vocab.load_vocab(status["vocab"])
        txt_logger.info("Observations preprocessor loaded")

        # Load model
        acmodel = ACModel(obs_space, envs[0].action_space, args.mem, args.text)
        if "model_state" in status:
            acmodel.load_state_dict(status["model_state"])
        acmodel.to(device)
        txt_logger.info("Model loaded\n")
        txt_logger.info("{}\n".format(acmodel))

        # Load algo (exploration model adapted from impact driven RL)
        autoencoder = AutoencoderWithUncertainty(observation_shape=(7, 7, 3)).to(
            device
        )
        autoencoder_opt = torch.optim.Adam(
            autoencoder.parameters(), lr=icm_lr, weight_decay=0
        )
        if args.algo == "a2c":
            algo = A2CAlgo(
                envs,
                acmodel,
                autoencoder,
                autoencoder_opt,
                args.uncertainty,
                args.noisy_tv,
                args.curiosity,
                args.randomise_env,
                args.uncertainty_budget,
                args.environment_seed,
                reward_weighting,
                normalise_rewards,
                args.frames_before_reset,
                device,
                args.frames_per_proc,
                args.discount,
                args.lr,
                args.gae_lambda,
                args.entropy_coef,
                args.value_loss_coef,
                args.max_grad_norm,
                args.recurrence,
                args.optim_alpha,
                args.optim_eps,
                preprocess_obss,
                None,
                args.random_action,
            )
        elif args.algo == "ppo":
            # Imported lazily so the a2c path does not depend on the ppo module.
            from .ppo import PPOAlgo

            # NOTE(review): unlike A2CAlgo, no frames_before_reset is passed
            # here -- confirm against PPOAlgo's constructor signature.
            algo = PPOAlgo(
                envs,
                acmodel,
                autoencoder,
                autoencoder_opt,
                args.uncertainty,
                args.noisy_tv,
                args.curiosity,
                args.randomise_env,
                args.uncertainty_budget,
                args.environment_seed,
                reward_weighting,
                normalise_rewards,
                device,
                args.frames_per_proc,
                args.discount,
                args.lr,
                args.gae_lambda,
                args.entropy_coef,
                args.value_loss_coef,
                args.max_grad_norm,
                args.recurrence,
                args.optim_eps,
                args.clip_eps,
                args.epochs,
                args.batch_size,
                preprocess_obss,
            )
        else:
            raise ValueError("Incorrect algorithm name: {}".format(args.algo))

        if "optimizer_state" in status:
            algo.optimizer.load_state_dict(status["optimizer_state"])
        txt_logger.info("Optimizer loaded\n")

        # Train model
        num_frames = status["num_frames"]
        update = status["update"]
        start_time = time.time()

        while num_frames < args.frames:
            # Update model parameters
            update_start_time = time.time()
            exps, logs1 = algo.collect_experiences()
            logs2 = algo.update_parameters(exps)
            logs = {**logs1, **logs2}
            update_end_time = time.time()

            num_frames += logs["num_frames"]
            update += 1

            log_to_wandb(logs, start_time, update_start_time, update_end_time)

            # Print logs (guarded like save_interval so 0 disables it).
            if args.log_interval > 0 and update % args.log_interval == 0:
                fps = logs["num_frames"] / (update_end_time - update_start_time)
                duration = int(time.time() - start_time)
                rreturn_per_episode = utils.synthesize(
                    logs["reshaped_return_per_episode"]
                )
                num_frames_per_episode = utils.synthesize(
                    logs["num_frames_per_episode"]
                )
                data = [update, num_frames, fps, duration]
                data += rreturn_per_episode.values()
                data += num_frames_per_episode.values()
                data += [
                    logs["intrinsic_rewards"].mean().item(),
                    logs["uncertainties"].mean().item(),
                    logs["novel_states_visited"].mean().item(),
                    logs["entropy"],
                    logs["value"],
                    logs["policy_loss"],
                    logs["value_loss"],
                    logs["grad_norm"],
                ]
                # One placeholder per entry of `data`, in the same order, so
                # every label matches the value printed next to it.
                txt_logger.info(
                    "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | iR {:.3f} | unc {:.3f} | nS {:.3f} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}".format(
                        *data
                    )
                )

            # Save status
            if args.save_interval > 0 and update % args.save_interval == 0:
                status = {
                    "num_frames": num_frames,
                    "update": update,
                    "model_state": acmodel.state_dict(),
                    "optimizer_state": algo.optimizer.state_dict(),
                }
                if hasattr(preprocess_obss, "vocab"):
                    status["vocab"] = preprocess_obss.vocab.vocab
                utils.save_status(status, model_dir)
    finally:
        # Release the log file handle and flush TensorBoard events even when
        # training stops with an exception.
        csv_file.close()
        tb_writer.close()
    return