示例#1
0
def train_agent(n_iterations):
    """Run the agent training loop for `n_iterations` steps.

    Alternates experience collection (via the collect driver) with a
    gradient update on one batch drawn from the replay dataset, and
    periodically logs metrics, saves the policy, and writes summaries.
    """
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    batches = iter(dataset)

    for step in range(n_iterations):
        # Collect one driver run of experience, then train on a sampled batch.
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, _ = next(batches)
        train_loss = agent.train(trajectories)
        print("\r{} loss:{:.5f}".format(step, train_loss.loss.numpy()),
              end="")

        if step % config.TRAINING_LOG_INTERVAL == 0:
            utils.print_time_stats(train_start, step)
            print("\r")
            log_metrics(train_metrics)

        if step % config.TRAINING_SAVE_POLICY_INTERVAL == 0:
            save_agent_policy()

        if step % config.TRAINING_LOG_MEASURES_INTERVAL == 0:
            # calculate and report the total return over 1 episode
            utils.write_summary("AverageReturnMetric",
                                train_metrics[2].result(), step)
            utils.write_summary("AverageEpisodeLengthMetric",
                                train_metrics[3].result(), step)
            utils.writer.flush()

    # Final save/flush so the last iterations are never lost.
    save_agent_policy()
    utils.writer.flush()
示例#2
0
def train(train_loader, encoder, binarizer, decoder, optimizer, scheduler,
          epoch, iteration, args):
    '''Train model for a single epoch'''

    start_time = time.time()
    for _, (images, _) in enumerate(train_loader):
        n = images.shape[0]
        if args.gpu:
            images = images.cuda()
        data_time = time.time() - start_time

        # Fresh recurrent states for this batch.
        enc_state = encoder.create_hidden(n, gpu=args.gpu)
        dec_state = decoder.create_hidden(n, gpu=args.gpu)

        # Iterative residual compression: encode, binarize, decode, subtract
        # the reconstruction, and penalize what remains.
        residual = images
        losses = []
        for _ in range(args.compression_iters):
            enc_out, enc_state = encoder(residual, enc_state)
            code = binarizer(enc_out)
            dec_out, dec_state = decoder(code, dec_state)
            residual = residual - dec_out
            losses.append(residual.abs().mean())  # mean absolute error

        # Backprop
        optimizer.zero_grad()
        loss = sum(losses) / args.compression_iters
        loss.backward()
        optimizer.step()
        compute_time = time.time() - data_time - start_time
        loss = loss.item()

        # Log
        iteration += 1
        if iteration % args.log_every == 0:
            if args.tensorboard is not None:
                utils.write_summary(args.tensorboard, 'train_loss', loss,
                                    iteration)
                utils.write_summary(args.tensorboard, 'learning_rate',
                                    optimizer.current_lr, iteration)
                tf_summary_writer.flush()
            print(('[Train] Epoch {e:5d} '
                   '| Iter {i:6d} '
                   '| Loss {l:10.4f} '
                   '| Compute time {ct:8.2f} '
                   '| Data time {dt:8.2f} ').format(e=epoch,
                                                    i=iteration,
                                                    l=loss,
                                                    ct=compute_time,
                                                    dt=data_time))

    return iteration, epoch + 1, loss
示例#3
0
def train_epoch(model, dataloader, optimizer, epoch_index, writer,
                print_every=10):
    """Train `model` for one epoch over `dataloader`.

    Computes a length-masked mean L1 loss per batch, backpropagates, and
    periodically prints and logs the running loss.

    Args:
        model: callable mapping input ids to per-timestep predictions.
        dataloader: yields (data, label_r, ori_length_list) batches.
        optimizer: torch optimizer, stepped once per batch.
        epoch_index: current epoch number (used to derive the global step).
        writer: summary writer forwarded to `write_summary`.
        print_every: print/log every this many iterations.

    Returns:
        Mean masked L1 loss over the epoch as a Python float.
    """
    epoch_l1_loss = 0.0
    print_l1_loss = 0.0

    batches = len(dataloader)
    for iter_index, train_data in enumerate(dataloader, start=1):
        total_iter_index = iter_index + epoch_index * batches

        optimizer.zero_grad()
        data, label_r, ori_length_list = train_data
        # Mask covering the real (non-padded) timesteps of each sequence.
        sequences_mask = sequence_mask_torch(ori_length_list,
                                             max_len=hp.max_len)

        data, label_r = data.long(), label_r.float()
        if torch.cuda.is_available():
            data, label_r, ori_length_list = (data.cuda(), label_r.cuda(),
                                              ori_length_list.cuda())
            sequences_mask = sequences_mask.cuda()

        predictions = model(data)
        l1_loss = loss_function(predictions, label_r)
        l1_loss = torch.sum(l1_loss, dim=-1)
        # Average only over unmasked (real) positions.
        l1_loss = torch.sum(l1_loss * sequences_mask) / torch.sum(
            sequences_mask)

        l1_loss.backward()
        optimizer.step()

        # BUG FIX: accumulate a detached Python float instead of the loss
        # tensor; summing tensors keeps every batch's autograd graph alive
        # and leaks (GPU) memory over the epoch.
        batch_loss = l1_loss.item()
        print_l1_loss += batch_loss
        epoch_l1_loss += batch_loss

        if iter_index % print_every == 0:
            # BUG FIX: separator was a stray backslash ("[10\100]"); use "/"
            # for the intended "iteration/total" display.
            print("epoch:{}\titeration:[{}/{}]\tl1_loss:{}".format(
                epoch_index, iter_index, len(dataloader),
                print_l1_loss / print_every))
            write_summary(writer, print_l1_loss / print_every,
                          total_iter_index)
            print_l1_loss = 0.0

    return epoch_l1_loss / batches
示例#4
0
 def sample_sotl(self, policy, task=None, batch_id=None, params=None):
     """Roll out `policy` in the batched envs until every episode is done,
     then write a summary for the run."""
     # Enqueue one work item per episode plus a stop sentinel per worker.
     for episode_idx in range(self.batch_size):
         self.queue.put(episode_idx)
     for _ in range(self.num_workers):
         self.queue.put(None)

     observations, batch_ids = self.envs.reset()
     dones = [False]
     if params:  # todo precise load parameter logic
         policy.load_params(params)

     while not all(dones):
         actions = policy.choose_action(observations)
         ## for multi_intersection
         actions = np.reshape(actions, (-1, 1))
         new_observations, rewards, dones, new_batch_ids, _ = \
             self.envs.step(actions)
         observations, batch_ids = new_observations, new_batch_ids

     write_summary(self.dic_path, task, self.dic_exp_conf["EPISODE_LEN"], 0,
                   self.dic_traffic_env_conf['FLOW_FILE'])
示例#5
0
    def single_test_sample(self, policy, task, batch_id, params):
        """Evaluate `policy` on one task in a fresh CityFlow env and write a
        summary of the episode."""
        policy.load_params(params)

        # Per-task deep copies so mutations do not leak into shared configs.
        dic_traffic_env_conf = copy.deepcopy(self.dic_traffic_env_conf)
        dic_traffic_env_conf['TRAFFIC_FILE'] = task

        dic_path = copy.deepcopy(self.dic_path)
        dic_path["PATH_TO_LOG"] = os.path.join(
            dic_path['PATH_TO_WORK_DIRECTORY'], 'test_round', task,
            'tasks_round_' + str(batch_id))
        if not os.path.exists(dic_path['PATH_TO_LOG']):
            os.makedirs(dic_path['PATH_TO_LOG'])

        dic_exp_conf = copy.deepcopy(self.dic_exp_conf)

        env = CityFlowEnv(path_to_log=dic_path["PATH_TO_LOG"],
                          path_to_work_directory=dic_path["PATH_TO_DATA"],
                          dic_traffic_env_conf=dic_traffic_env_conf)

        # Episode length expressed in environment decision steps.
        max_steps = int(dic_exp_conf["EPISODE_LEN"] /
                        dic_traffic_env_conf["MIN_ACTION_TIME"])

        state = env.reset()
        done = False
        step_num = 0
        stop_cnt = 0
        while not done and step_num < max_steps:
            action_list = []
            for one_state in state:
                action = policy.choose_action(
                    [[one_state]], test=True
                )  # one for multi-state, the other for multi-intersection
                action_list.append(action[0])  # for multi-state

            state, reward, done, _ = env.step(action_list)
            step_num += 1
            stop_cnt += 1

        env.bulk_log()
        write_summary(dic_path, task, self.dic_exp_conf["EPISODE_LEN"],
                      batch_id, self.dic_traffic_env_conf['FLOW_FILE'])
示例#6
0
  def train(self,
            sess,
            train_input_paths,
            train_target_mask_paths,
            dev_input_paths,
            dev_target_mask_paths):
    """
    Defines the training loop.

    Runs epochs of SliceBatchGenerator batches, tracking an exponentially
    smoothed training loss and periodically logging, checkpointing, and
    evaluating dev loss plus train/dev dice coefficients to TensorBoard.

    Inputs:
    - sess: A TensorFlow Session object.
    - {train,dev}_{input_paths,target_mask_paths}: A list of Python strs
      that represent pathnames to input image files and target mask files.
    """
    # Count trainable parameters; shapes are evaluated eagerly via .eval().
    # NOTE(review): num_params is computed but never used below — presumably
    # kept for debugging; confirm before removing.
    params = tf.trainable_variables()
    num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params))

    # We will keep track of exponentially-smoothed loss
    exp_loss = None

    # Checkpoint management.
    # We keep one latest checkpoint, and one best checkpoint (early stopping)
    checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
    best_dev_dice_coefficient = None

    # For TensorBoard
    summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, sess.graph)

    epoch = 0
    num_epochs = self.FLAGS.num_epochs
    # num_epochs == None means train indefinitely.
    while num_epochs == None or epoch < num_epochs:
      epoch += 1

      # Loops over batches
      sbg = SliceBatchGenerator(train_input_paths,
                                train_target_mask_paths,
                                self.FLAGS.batch_size,
                                shape=(self.FLAGS.slice_height,
                                       self.FLAGS.slice_width),
                                use_fake_target_masks=self.FLAGS.use_fake_target_masks)
      num_epochs_str = str(num_epochs) if num_epochs != None else "indefinite"
      for batch in tqdm(sbg.get_batch(),
                        desc=f"Epoch {epoch}/{num_epochs_str}",
                        total=sbg.get_num_batches()):
        # Runs training iteration
        loss, global_step, param_norm, grad_norm =\
          self.run_train_iter(sess, batch, summary_writer)

        # Updates exponentially-smoothed loss
        # NOTE(review): `not exp_loss` also triggers when exp_loss == 0.0,
        # resetting the average; presumably loss is never exactly 0 — confirm.
        if not exp_loss:  # first iter
          exp_loss = loss
        else:
          exp_loss = 0.99 * exp_loss + 0.01 * loss

        # Sometimes prints info
        if global_step % self.FLAGS.print_every == 0:
          logging.info(
            f"epoch {epoch}, "
            f"global_step {global_step}, "
            f"loss {loss}, "
            f"exp_loss {exp_loss}, "
            f"grad norm {grad_norm}, "
            f"param norm {param_norm}")

        # Sometimes saves model
        # NOTE(review): global_step is cumulative across epochs, so the
        # `global_step == sbg.get_num_batches()` end-of-epoch save only
        # matches during the first epoch — confirm intent.
        if (global_step % self.FLAGS.save_every == 0
            or global_step == sbg.get_num_batches()):
          self.saver.save(sess, checkpoint_path, global_step=global_step)

        # Sometimes evaluates model on dev loss, train F1/EM and dev F1/EM
        if global_step % self.FLAGS.eval_every == 0:
          # Logs loss for entire dev set to TensorBoard
          dev_loss = self.calculate_loss(sess,
                                         dev_input_paths,
                                         dev_target_mask_paths,
                                         "dev",
                                         self.FLAGS.dev_num_samples)
          logging.info(f"epoch {epoch}, "
                       f"global_step {global_step}, "
                       f"dev_loss {dev_loss}")
          utils.write_summary(dev_loss,
                              "dev/loss",
                              summary_writer,
                              global_step)

          # Logs dice coefficient on train set to TensorBoard
          train_dice = self.calculate_dice_coefficient(sess,
                                                       train_input_paths,
                                                       train_target_mask_paths,
                                                       "train")
          logging.info(f"epoch {epoch}, "
                       f"global_step {global_step}, "
                       f"train dice_coefficient: {train_dice}")
          utils.write_summary(train_dice,
                              "train/dice",
                              summary_writer,
                              global_step)

          # Logs dice coefficient on dev set to TensorBoard
          dev_dice = self.calculate_dice_coefficient(sess,
                                                     dev_input_paths,
                                                     dev_target_mask_paths,
                                                     "dev")
          logging.info(f"epoch {epoch}, "
                       f"global_step {global_step}, "
                       f"dev dice_coefficient: {dev_dice}")
          utils.write_summary(dev_dice,
                              "dev/dice",
                              summary_writer,
                              global_step)
      # end for batch in sbg.get_batch
    # end while num_epochs == 0 or epoch < num_epochs
    sys.stdout.flush()
示例#7
0
    def sample_meta_test(self,
                         policy,
                         task,
                         batch_id,
                         params=None,
                         target_params=None,
                         old_episodes=None):
        """Meta-test sampling loop: roll out `policy` in the batched envs,
        periodically fit/update its parameters online, and periodically run
        a single-test evaluation, dumping params and writing summaries.

        Returns the final (params, target_params, episodes) so the caller
        can continue from this state.
        """
        # Enqueue one work item per episode plus a stop sentinel per worker.
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)
        episodes = BatchEpisodes(dic_agent_conf=self.dic_agent_conf,
                                 old_episodes=old_episodes)
        observations, batch_ids = self.envs.reset()
        dones = [False]
        if params:  # todo precise load parameter logic
            policy.load_params(params)

        # Continue while any env is unfinished or work items remain queued.
        while (not all(dones)) or (not self.queue.empty()):
            actions = policy.choose_action(observations)
            ## for multi_intersection
            actions = np.reshape(actions, (-1, 1))
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations, actions, new_observations, rewards,
                            batch_ids)
            observations, batch_ids = new_observations, new_batch_ids

            # Periodic online update: fit on replayed episodes, then derive
            # fresh params from a random slice of the memory.
            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'UPDATE_PERIOD'] == 0:
                # Cap replay memory length before fitting.
                if len(episodes) > self.dic_agent_conf['MAX_MEMORY_LEN']:
                    episodes.forget()

                policy.fit(episodes,
                           params=params,
                           target_params=target_params)
                sample_size = min(self.dic_agent_conf['SAMPLE_SIZE'],
                                  len(episodes))
                slice_index = random.sample(range(len(episodes)), sample_size)
                # deepcopy so update_params cannot mutate the params we hold.
                params = policy.update_params(episodes,
                                              params=copy.deepcopy(params),
                                              lr_step=self.lr_step,
                                              slice_index=slice_index)

                policy.load_params(params)

                self.lr_step += 1
                self.target_step += 1
                # Sync the target network every UPDATE_Q_BAR_FREQ updates.
                if self.target_step == self.dic_agent_conf[
                        'UPDATE_Q_BAR_FREQ']:
                    target_params = params
                    self.target_step = 0

            # Periodic evaluation: run a single test episode, persist the
            # current params, and write a summary.
            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'TEST_PERIOD'] == 0:
                self.single_test_sample(policy,
                                        task,
                                        self.test_step,
                                        params=params)
                pickle.dump(
                    params,
                    open(
                        os.path.join(
                            self.dic_path['PATH_TO_MODEL'],
                            'params' + "_" + str(self.test_step) + ".pkl"),
                        'wb'))
                write_summary(self.dic_path, task,
                              self.dic_traffic_env_conf["EPISODE_LEN"],
                              batch_id)

                self.test_step += 1
            self.step += 1

        policy.decay_epsilon(batch_id)
        self.envs.bulk_log()
        return params, target_params, episodes
示例#8
0
def train(generator,
          discriminator,
          g_optim,
          d_optim,
          scheduler,
          data_loader,
          mixing_epochs,
          stabilizing_epochs,
          phase,
          writer,
          horovod=False):
    """Train one progressive-growing phase: a mixing stage where the fade-in
    coefficient `alpha` anneals from 1 to 0, followed by a stabilizing stage
    at alpha == 0. Logs scalars/images to TensorBoard via `write_summary`,
    saves model snapshots every 16 epochs, and records metrics at the end of
    each stage.
    """
    alpha = 1  # Mixing parameter.
    for epoch in range(mixing_epochs):
        if horovod:
            # Sync model/optimizer state from rank 0 and reshuffle shards.
            hvd.broadcast_parameters(generator.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(g_optim, root_rank=0)
            hvd.broadcast_parameters(discriminator.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(d_optim, root_rank=0)
            data_loader.sampler.set_epoch(epoch)

        start = time.perf_counter()
        x_fake, x_real, *scalars = train_epoch(data_loader, generator,
                                               discriminator, g_optim, d_optim,
                                               alpha)

        end = time.perf_counter()

        images_per_second = len(data_loader.dataset) / (end - start)

        # Tensorboard
        g_lr = g_optim.param_groups[0]['lr']
        d_lr = d_optim.param_groups[0]['lr']
        scalars = list(scalars) + [epoch, alpha, g_lr, d_lr, images_per_second]
        # Global step counts epochs across all completed phases.
        global_step = (phase - 1) * (mixing_epochs +
                                     stabilizing_epochs) + epoch
        # NOTE(review): approximate if the last batch is partial — confirm
        # drop_last on the loader if exact counts matter.
        images_seen = global_step * len(data_loader) * data_loader.batch_size
        if horovod:
            images_seen = images_seen * hvd.size()

        if writer:
            write_summary(writer, images_seen, x_real[0], x_fake[0], scalars)

        # Update alpha
        alpha -= 1 / mixing_epochs
        # Small negative tolerance absorbs float accumulation error.
        assert alpha >= -1e-4, alpha

        scheduler.step()

        if epoch % 16 == 0 and writer:
            print(
                f'Epoch: {epoch} \t Images Seen: {images_seen} \t '
                f'Discriminator Loss: {scalars[0]:.4f} \t Generator Loss: {scalars[1]:.4f}'
            )
            discriminator.eval()
            generator.eval()
            torch.save(
                discriminator.state_dict(),
                os.path.join(writer.log_dir,
                             f'discriminator_phase_{phase}_epoch_{epoch}.pt'))
            torch.save(
                generator.state_dict(),
                os.path.join(writer.log_dir,
                             f'generator_phase_{phase}_epoch_{epoch}.pt'))

    # End-of-mixing-stage metrics on the last batch of the stage.
    d_dict = get_metrics(x_real.detach().cpu().numpy(),
                         x_fake.detach().cpu().numpy())

    if writer:
        for d in d_dict:
            writer.add_scalar(d, d_dict[d], global_step)

    # Stabilizing stage: train with the new resolution fully blended in.
    alpha = 0

    for epoch in range(mixing_epochs, mixing_epochs + stabilizing_epochs):
        if horovod:
            hvd.broadcast_parameters(generator.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(g_optim, root_rank=0)
            hvd.broadcast_parameters(discriminator.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(d_optim, root_rank=0)
            data_loader.sampler.set_epoch(epoch)

        start = time.perf_counter()
        x_fake, x_real, *scalars = train_epoch(data_loader, generator,
                                               discriminator, g_optim, d_optim,
                                               alpha)
        end = time.perf_counter()

        images_per_second = len(data_loader.dataset) / (end - start)
        # Tensorboard
        # NOTE(review): these per-epoch metrics are computed but not written
        # anywhere (shadowed by the post-loop d_dict) — confirm intent.
        d_dict = get_metrics(x_real.detach().cpu().numpy(),
                             x_fake.detach().cpu().numpy())
        g_lr = g_optim.param_groups[0]['lr']
        d_lr = d_optim.param_groups[0]['lr']
        scalars = list(scalars) + [epoch, alpha, g_lr, d_lr, images_per_second]
        global_step = (phase - 1) * (mixing_epochs +
                                     stabilizing_epochs) + epoch
        images_seen = global_step * len(data_loader) * data_loader.batch_size
        if horovod:
            images_seen = images_seen * hvd.size()

        if writer:
            write_summary(writer, images_seen, x_real[0], x_fake[0], scalars)

        # NOTE(review): unlike the mixing stage, scheduler.step() is not
        # called here — confirm this is intentional.
        if epoch % 16 == 0 and writer:
            print(
                f'Epoch: {epoch} \t Images Seen: {images_seen} \t '
                f'Discriminator Loss: {scalars[0]:.4f} \t Generator Loss: {scalars[1]:.4f}'
            )

            discriminator.eval()
            generator.eval()
            torch.save(
                discriminator.state_dict(),
                os.path.join(writer.log_dir,
                             f'discriminator_phase_{phase}_epoch_{epoch}.pt'))
            torch.save(
                generator.state_dict(),
                os.path.join(writer.log_dir,
                             f'generator_phase_{phase}_epoch_{epoch}.pt'))

    # End-of-phase metrics on the last stabilizing batch.
    d_dict = get_metrics(x_real.detach().cpu().numpy(),
                         x_fake.detach().cpu().numpy())
    if writer:
        for d in d_dict:
            writer.add_scalar(d, d_dict[d], global_step)
示例#9
0
def train_epoch(model,
                dataloader,
                optimizer,
                epoch_index,
                writer,
                print_every=10):
    """Train `model` for one epoch over `dataloader`.

    Each batch's total/L1/smoothness losses come from `compute_loss`; the
    total loss is backpropagated, and running averages are printed and
    logged every `print_every` iterations.

    Args:
        model: callable mapping input ids to per-timestep predictions.
        dataloader: yields (data, position, label_r, ori_length_list).
        optimizer: torch optimizer, stepped once per batch.
        epoch_index: current epoch number (used to derive the global step).
        writer: summary writer forwarded to `write_summary`.
        print_every: print/log every this many iterations.

    Returns:
        (mean_l1_loss, mean_smooth_loss, mean_total_loss) over the epoch,
        as Python floats.
    """
    epoch_loss = 0.0
    epoch_l1_loss = 0.0
    epoch_smooth_loss = 0.0

    print_loss = 0.0
    print_l1_loss = 0.0
    print_smooth_loss = 0.0
    batches = len(dataloader)

    for iter_index, train_data in enumerate(dataloader, start=1):
        total_iter_index = iter_index + epoch_index * batches

        optimizer.zero_grad()

        data, position, label_r, ori_length_list = train_data
        # Mask covering the real (non-padded) timesteps of each sequence.
        # NOTE(review): computed but unused since compute_loss takes the raw
        # lengths; kept for parity — confirm before removing.
        sequences_mask = sequence_mask_torch(ori_length_list,
                                             max_len=hp.max_len)
        data, position, label_r = data.long(), position.long(), label_r.float()
        if torch.cuda.is_available():
            data, position, label_r, ori_length_list = (
                data.cuda(), position.cuda(), label_r.cuda(),
                ori_length_list.cuda())
            sequences_mask = sequences_mask.cuda()

        predictions = model(data)
        loss, l1_loss, smooth_loss = compute_loss(predictions, label_r,
                                                  ori_length_list)

        loss.backward()
        optimizer.step()

        # BUG FIX: accumulate detached Python floats rather than the loss
        # tensors; summing tensors keeps every batch's autograd graph alive
        # and leaks (GPU) memory over the epoch.
        loss_val = loss.item()
        l1_val = l1_loss.item()
        smooth_val = smooth_loss.item()

        print_loss += loss_val
        print_l1_loss += l1_val
        print_smooth_loss += smooth_val

        if iter_index % print_every == 0:
            # BUG FIX: separator was a stray backslash ("[10\100]"); use "/"
            # for the intended "iteration/total" display.
            print(
                "epoch:{}\titeration:[{}/{}]\tl1_loss:{}\tsmooth_loss:{}\tloss:{}"
                .format(epoch_index, iter_index, len(dataloader),
                        print_l1_loss / print_every,
                        print_smooth_loss / print_every,
                        print_loss / print_every))
            write_summary(writer, print_l1_loss / print_every,
                          print_smooth_loss / print_every,
                          print_loss / print_every, total_iter_index)
            print_l1_loss = 0.0
            print_loss = 0.0
            print_smooth_loss = 0.0

        epoch_loss += loss_val
        epoch_l1_loss += l1_val
        epoch_smooth_loss += smooth_val

    return epoch_l1_loss / batches, epoch_smooth_loss / batches, epoch_loss / batches