def _update_network(self):
    """Sample a replay batch and apply one actor update and one critic update.

    The batch comes from ``self.buffer.sample(self.args.batch_size)``; the
    sampled dict is modified in place (``obs``, ``g``, ``obs_next`` are
    overwritten and ``g_next`` is added) with the outputs of
    ``self._preproc_og``.  Observations and goals are normalized with
    ``self.o_norm`` / ``self.g_norm``, concatenated along axis 1 and turned
    into float32 tensors (moved to CUDA when ``self.args.cuda`` is truthy).

    The critic is regressed onto ``r + gamma * Q_target(s', pi_target(s'))``
    clamped to ``[-1 / (1 - gamma), 0]``.  The actor minimizes the negated
    critic value of its own actions plus an L2 penalty on the actions
    divided by ``self.env_params['action_max']``, weighted by
    ``self.args.action_l2``.  ``sync_grads`` is called on each network
    between ``backward()`` and the optimizer step.
    """
    batch = self.buffer.sample(self.args.batch_size)

    # Grab the raw arrays first: the same raw goal is used for both the
    # current and the next observation.
    raw_obs, raw_obs_next, raw_goal = (batch['obs'], batch['obs_next'],
                                       batch['g'])
    batch['obs'], batch['g'] = self._preproc_og(raw_obs, raw_goal)
    batch['obs_next'], batch['g_next'] = self._preproc_og(
        raw_obs_next, raw_goal)

    def normalized_input(obs_key, goal_key):
        # normalized observation and goal, joined feature-wise
        obs_part = self.o_norm.normalize(batch[obs_key])
        goal_part = self.g_norm.normalize(batch[goal_key])
        return np.concatenate([obs_part, goal_part], axis=1)

    def as_float_tensor(array):
        return torch.tensor(array, dtype=torch.float32)

    state_goal = normalized_input('obs', 'g')
    next_state_goal = normalized_input('obs_next', 'g_next')

    inputs_t = as_float_tensor(state_goal)
    next_inputs_t = as_float_tensor(next_state_goal)
    actions_t = as_float_tensor(batch['actions'])
    rewards_t = as_float_tensor(batch['r'])
    if self.args.cuda:
        inputs_t, next_inputs_t, actions_t, rewards_t = [
            t.cuda()
            for t in (inputs_t, next_inputs_t, actions_t, rewards_t)
        ]

    # Bootstrapped target from the target networks; no gradients needed.
    with torch.no_grad():
        next_actions = self.actor_target_network(next_inputs_t)
        next_q = self.critic_target_network(next_inputs_t,
                                            next_actions).detach()
        td_target = (rewards_t + self.args.gamma * next_q).detach()
        # keep the target inside [-1 / (1 - gamma), 0]
        bound = 1 / (1 - self.args.gamma)
        td_target = torch.clamp(td_target, -bound, 0)

    # critic: mean squared error against the clamped target
    q_pred = self.critic_network(inputs_t, actions_t)
    critic_loss = (td_target - q_pred).pow(2).mean()

    # actor: maximize the critic's value, with an L2 penalty on the
    # actions relative to the maximum action
    policy_actions = self.actor_network(inputs_t)
    actor_loss = -self.critic_network(inputs_t, policy_actions).mean()
    actor_loss += self.args.action_l2 * (
        policy_actions / self.env_params['action_max']).pow(2).mean()

    # Actor first, then critic.  The critic's gradients are zeroed only
    # after the actor's backward pass, which also flows through the critic.
    update_plan = (
        (self.actor_optim, actor_loss, self.actor_network),
        (self.critic_optim, critic_loss, self.critic_network),
    )
    for optimizer, loss, network in update_plan:
        optimizer.zero_grad()
        loss.backward()
        sync_grads(network)
        optimizer.step()
def train(rank,
          args,
          tokenizer,
          train_dataset,
          test_dataset,
          model_s,
          model_t,
          params_to_tune,
          head_importance=None,
          loss_num=-1,
          tune_iter=0):
    """Train the student ``model_s`` to reproduce the hidden states of ``model_t``.

    For every batch both models are run with ``output_hidden_states=True``
    (the teacher under ``torch.no_grad()``), student and teacher hidden
    states are paired layer by layer, and the mean squared difference of
    every pair is used as a loss.  Every 50th step (when it is also an
    optimizer step) the averaged loss is compared with the last accepted
    one: if it is below 1.5x of it, the student and optimizer state are
    saved to ``<output_dir>/last_good_state.pth``; otherwise the student is
    restored from that file and the optimizer state is reset.

    Increments the module-level ``train_count``.

    Args:
        rank: process rank.  A negative value means single-process training
            (no ``torch.distributed`` calls); ranks -1 and 0 do the logging,
            checkpointing and evaluation.
        args: options object.  Reads ``device``, ``num_train_epochs``,
            ``per_gpu_train_batch_size``, ``total_train_batch_size``,
            ``total_train_batch_size_for_tune``, ``learning_rate``,
            ``learning_rate_for_tune``, ``loss_weight_alpha``, ``n_gpu``
            and ``output_dir``.
        tokenizer: forwarded to ``save_model``.
        train_dataset: training dataset.  When ``check_model_type(model_s,
            BertModelEMB)`` is true it must expose ``q_dataset`` and
            ``c_dataset``, which are iterated one after the other.
        test_dataset: forwarded to ``evaluate``.
        model_s: student model (trained).
        model_t: teacher model (only used for forward passes).
        params_to_tune: collection of parameter objects; only parameters of
            ``model_s`` that are identical (``is``) to one of them are
            optimized.  It is iterated once per model parameter, so it must
            be re-iterable (not a one-shot generator).
        head_importance: optional tensor.  When given, an all-ones head mask
            of the same shape is passed to the student and the tensor is
            updated in place with a running average (factor 0.001) of the
            absolute head-mask gradient.
        loss_num: when > 0 only the first ``loss_num`` hidden-state pairs
            are used, training is shortened to 0.25 epoch and the per-epoch
            checkpoint/evaluation is skipped.
        tune_iter: when > 0 the ``*_for_tune`` batch size / learning rate
            are used if set, the learning rate decays linearly after
            warmup, and the per-layer losses are weighted by
            ``args.loss_weight_alpha ** layer_index``.

    Returns:
        None.
    """
    global train_count
    train_count += 1

    world_size = 1 if rank < 0 else torch.distributed.get_world_size()

    if rank in [-1, 0]:
        printlog("Train stage: ", train_count)
        printlog(model_s)

    if head_importance is not None:
        # all-ones mask: it does not change the forward pass, but its
        # gradient is used below to accumulate head importance
        head_mask = torch.ones(*list(head_importance.shape)).to(args.device)
        head_mask.requires_grad_(requires_grad=True)
    else:
        head_mask = None

    num_train_epochs = args.num_train_epochs
    if loss_num > 0:
        num_train_epochs = 0.25  # short train for incremental loss

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size

    # total (effective) batch size; tuning iterations may override it
    if tune_iter > 0 and args.total_train_batch_size_for_tune:
        total_train_batch_size = args.total_train_batch_size_for_tune
    else:
        total_train_batch_size = args.total_train_batch_size
    # At least 1: a total batch size smaller than one pass over all
    # processes would otherwise give 0 and fail with ZeroDivisionError below.
    gradient_accumulation_steps = max(
        1, total_train_batch_size // train_batch_size)

    if tune_iter > 0 and args.learning_rate_for_tune:
        learning_rate = args.learning_rate_for_tune
    else:
        learning_rate = args.learning_rate

    if check_model_type(model_s, BertModelEMB):
        # use 2 datasets for embedding question and context separately
        if rank in [-1, 0]:
            printlog("dataset_q size", len(train_dataset.q_dataset))
            printlog("dataset_c size", len(train_dataset.c_dataset))
        datasets = [train_dataset.q_dataset, train_dataset.c_dataset]
    else:
        if rank in [-1, 0]:
            printlog("dataset size", len(train_dataset))
        datasets = [train_dataset]

    if rank > -1:
        # for distributed training use a sampler that gives each process
        # only its own part of the samples
        train_dataloaders = [
            DataLoader(dataset,
                       sampler=torch.utils.data.distributed.DistributedSampler(
                           dataset, rank=rank),
                       batch_size=per_gpu_train_batch_size)
            for dataset in datasets
        ]
    else:
        train_dataloaders = [
            DataLoader(dataset,
                       sampler=RandomSampler(dataset),
                       batch_size=train_batch_size,
                       num_workers=4) for dataset in datasets
        ]

    steps_per_epoch = sum(len(d) for d in train_dataloaders)
    steps_total = int(steps_per_epoch // gradient_accumulation_steps *
                      num_train_epochs)

    # Prepare optimizer and scheduler
    # select the student's parameters that are (by identity) in params_to_tune
    name_set = set()
    for n, p in model_s.named_parameters():
        if any(p is pp for pp in params_to_tune):
            name_set.add(n)
    named_params = [(n, p) for n, p in model_s.named_parameters()
                    if n in name_set]

    if rank in [-1, 0]:
        for n, p in named_params:
            printlog('param for tune', n)

    optimizer = AdamW([p for n, p in named_params],
                      lr=learning_rate,
                      eps=1e-08,
                      weight_decay=0.0)

    def lr_lambda(current_step):
        """Learning-rate factor: 1% linear warmup, then constant (first
        iteration) or linear decay to 0 (tuning iterations)."""
        # max(..., 1): LambdaLR calls this while it is being constructed,
        # so steps_total == 0 (very short runs) must not divide by zero.
        p = float(current_step) / float(max(steps_total, 1))
        warmup = 0.01
        if p < warmup:
            return p / warmup
        p = (p - warmup) / (1 - warmup)
        return 1 if tune_iter == 0 else max(1 - p, 0)

    scheduler = LambdaLR(optimizer, lr_lambda)

    if rank in [-1, 0]:
        printlog("epoches", num_train_epochs)
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size)
        printlog("n_gpu", args.n_gpu)
        printlog("world_size", world_size)
        printlog("gradient_accumulation_steps", gradient_accumulation_steps)
        printlog("total train batch size",
                 train_batch_size * gradient_accumulation_steps)
        printlog("steps_total", steps_total)

    restore_count = 0
    if rank in [-1, 0]:
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
    # file with the last accepted student/optimizer state
    restore_file = os.path.join(args.output_dir, 'last_good_state.pth')
    restore_loss = None

    losses_list = []

    global_step = 0
    for epoch in range(math.ceil(num_train_epochs)):
        switch_to_train(rank, model_t)
        switch_to_train(rank, model_s)
        model_s.zero_grad()
        utils.sync_models(rank, model_s)

        time_last = time.time()
        for train_dataloader in train_dataloaders:
            printlog("rank", rank, "len(train_dataloader)",
                     len(train_dataloader))
            if rank > -1:
                train_dataloader.sampler.set_epoch(epoch)

            if len(train_dataloaders) > 1:
                # reset last loss to avoid restore due to dataset changing
                printlog("rank", rank, "reset restore_loss")
                restore_loss = None

            for step, batch in enumerate(train_dataloader):
                # fractional epoch; allows num_train_epochs to be non-integer
                epoch_fp = epoch + step / len(train_dataloader)
                if epoch_fp > num_train_epochs:
                    break

                inputs = {
                    'input_ids': batch[0].to(args.device),
                    'attention_mask': batch[1].to(args.device),
                    'token_type_ids': batch[2].to(args.device)
                }

                outputs_s = model_s(**inputs,
                                    head_mask=head_mask,
                                    output_hidden_states=True)
                losses = []

                with torch.no_grad():
                    outputs_t = model_t(**inputs, output_hidden_states=True)

                # hidden states are taken from the last element of the outputs
                out_s, out_t = outputs_s[-1], outputs_t[-1]

                assert len(
                    out_s
                ) == model_s.config.num_hidden_layers + 1, "can not find hidden states in student model outputs"
                assert len(
                    out_t
                ) == model_t.config.num_hidden_layers + 1, "can not find hidden states in teacher model outputs"
                if len(out_s) != len(out_t):
                    # the student and teacher outputs are not aligned:
                    # pick a teacher output for each student output by
                    # spreading the indices evenly over the teacher's
                    n_s, n_t = len(out_s), len(out_t)
                    out_t = [
                        out_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s)
                    ]
                assert len(out_s) == len(
                    out_t
                ), "can not align number of outputs between student and teacher"
                assert all(
                    s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)
                ), "output shapes for student and teacher are not the same"

                out_pairs = list(zip(out_s, out_t))
                if loss_num > 0:
                    out_pairs = out_pairs[:loss_num]

                losses += [(s - t.detach()).pow(2).mean()
                           for s, t in out_pairs]

                losses_list.append([l.item() for l in losses])

                if tune_iter == 0:
                    loss = sum(losses) / len(losses)
                else:
                    # weighted average with geometric per-layer weights
                    weights = [
                        args.loss_weight_alpha**i for i in range(len(losses))
                    ]
                    losses_w = [w * l for w, l in zip(weights, losses)]
                    loss = sum(losses_w) / sum(weights)

                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()
                del out_s
                del out_t
                del outputs_s
                del outputs_t

                if head_importance is not None:
                    # collect gradient statistics to find most valuable heads
                    head_mask.grad.detach_()
                    head_importance += (head_mask.grad.abs().detach() -
                                        head_importance) * 0.001
                    head_mask.grad.zero_()

                if (step + 1) % gradient_accumulation_steps == 0:
                    global_step += 1

                    # sync gradients before the optimizer step
                    utils.sync_grads(rank, named_params, global_step == 1)

                    torch.nn.utils.clip_grad_norm_(
                        [p for n, p in named_params], 1)
                    optimizer.step()
                    scheduler.step()

                    model_s.zero_grad()

                    if (step + 1) % 50 == 0:
                        str_out = "{} ep {:.2f} lrp {:.2f} rc {:02}".format(
                            train_count, epoch_fp,
                            np.log10(scheduler.get_last_lr()[0]),
                            restore_count)
                        # per-layer losses averaged since the last report
                        ll = np.array(losses_list).mean(0)

                        if rank > -1:
                            # sync indicators
                            llt = torch.tensor(ll).to(args.device)
                            torch.distributed.all_reduce(
                                llt, op=torch.distributed.ReduceOp.SUM)
                            ll = llt.cpu().numpy() / float(world_size)

                        loss = ll.mean()
                        str_out += " loss {:.4f}".format(loss)
                        losses_txt = ["{:.3f}".format(l) for l in ll]
                        if tune_iter > 0:
                            losses_txt = [
                                "{:.2f}x".format(w) + lt
                                for w, lt in zip(weights, losses_txt)
                            ]
                        str_out += " ll " + " ".join(losses_txt)

                        if time_last:
                            dt_iter = (time.time() -
                                       time_last) / len(losses_list)
                            dt_ep = dt_iter * steps_per_epoch
                            str_out += " it {:.1f}s".format(dt_iter)
                            str_out += " ep {:.1f}m".format(dt_ep / (60))
                            str_out += " eta {:.1f}h".format(
                                dt_ep * (num_train_epochs - epoch_fp) /
                                (60 * 60))
                        losses_list = []
                        time_last = time.time()
                        if rank in [-1, 0]:
                            logger.info(str_out)

                        if rank > -1:
                            # sync losses so that every process takes the
                            # same save/restore decision
                            loss_tensor = torch.tensor([loss],
                                                       device=args.device)
                            torch.distributed.all_reduce(
                                loss_tensor, op=torch.distributed.ReduceOp.SUM)
                            loss = loss_tensor.item() / world_size

                        if restore_loss is None or loss < restore_loss * 1.5:
                            # good result, let's save it
                            restore_loss = loss

                            if rank in [-1, 0]:
                                torch.save(
                                    {
                                        'model_state_dict':
                                        model_s.state_dict(),
                                        'optimizer_state_dict':
                                        optimizer.state_dict()
                                    }, restore_file)
                            if rank > -1:
                                # wait until the file is written
                                torch.distributed.barrier()
                        else:
                            # bad result, let's restore
                            restore_count += 1
                            logger.info(
                                "rank {} restore #{} from {} with {} loss".
                                format(rank, restore_count, restore_file,
                                       restore_loss))
                            # Load to CPU; load_state_dict copies the values
                            # into the parameters on their own device, so no
                            # process allocates on the saving process' device.
                            checkpoint = torch.load(restore_file,
                                                    map_location='cpu')
                            model_s.load_state_dict(
                                checkpoint['model_state_dict'])
                            # Start with fresh optimizer statistics, but keep
                            # the optimizer object: the scheduler holds a
                            # reference to it, so replacing it with a new
                            # instance would leave the new optimizer at the
                            # unscheduled base learning rate.
                            optimizer.state.clear()
                            switch_to_train(rank, model_s)

        if loss_num <= 0:
            if rank in [-1, 0]:
                check_point_name = 'checkpoint-{:02}'.format(train_count)
                save_model(args, model_s, tokenizer, check_point_name)
                check_point_name = check_point_name + '-{:02}'.format(epoch +
                                                                      1)
                switch_to_eval(rank, model_s)
                result_s = evaluate(args, model_s, test_dataset)
                for k, v in result_s.items():
                    logger.info("{} {} {}".format(check_point_name, k, v))
            if rank > -1:
                torch.distributed.barrier()

    if rank in [-1, 0]:
        # the restore file is only needed during training
        if os.path.exists(restore_file):
            os.remove(restore_file)
Ejemplo n.º 3
0
def train(rank, args, model, model_t, train_dataset_qc, test_dataset_qc,
          fq_tune_only, model_controller):
    """Train the question/context embedding ``model``.

    Each step embeds a batch of questions and their positive contexts.  Two
    kinds of losses can be active:

    * distillation (when ``model_t`` is not None): mean squared difference
      between student and teacher hidden states, plus an optional embedding
      loss selected by ``emb_loss`` in ``args.loss_cfg`` (``L2`` or ``L1``),
      all multiplied by ``args.supervision_weight``;
    * triplet (when ``triplet_num`` from ``args.loss_cfg`` is > 0 and
      ``fq_tune_only`` is false): ``softplus(d_pos - d_neg)`` where the
      negative context is the best scoring one among recently embedded
      random contexts (kept in a history of ``args.hnm_hist_num`` entries),
      or a random non-positive context when no history is available.

    After every epoch, rank -1/0 evaluates the model with ``evaluate``.
    Increments the module-level ``train_count``.

    Args:
        rank: process rank.  A negative value means single-process training
            (no ``torch.distributed`` calls); ranks -1 and 0 do the logging
            and evaluation.
        args: options object.  Reads ``device``, ``per_gpu_train_batch_size``,
            ``total_train_batch_size``, ``num_train_epochs``,
            ``learning_rate``, ``freeze_list``, ``loss_cfg``,
            ``supervision_weight``, ``hnm_num``, ``hnm_batch_size``,
            ``hnm_hist_num``, ``hnm_hist_alpha`` and ``n_gpu``.
            ``args.hnm_num`` may be increased in place to make it divisible
            by the number of processes.
        model: model being trained.
        model_t: optional teacher model; when not None, distillation losses
            are added.
        train_dataset_qc: object exposing ``q_dataset`` (iterated by the
            data loader) and ``c_dataset`` (indexed by context ids).
        test_dataset_qc: forwarded to ``evaluate``.
        fq_tune_only: when true, a single epoch without gradient
            accumulation is run and the triplet losses are disabled; also
            forwarded to ``utils.make_param_groups``.
        model_controller: forwarded to ``utils.make_param_groups``.

    Raises:
        ValueError: if ``triplet_num`` in ``args.loss_cfg`` is greater than
            2 (only two triplet variants are implemented) or is not an
            integer.
        Exception: if ``emb_loss`` in ``args.loss_cfg`` is not supported.
    """
    global train_count
    train_count += 1

    world_size = 1 if rank < 0 else torch.distributed.get_world_size()

    if rank in [-1, 0]:
        printlog("Train model", train_count)
        printlog(model)

    q_dataset = train_dataset_qc.q_dataset

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size

    if fq_tune_only:
        gradient_accumulation_steps = 1
        num_train_epochs = 1
    else:
        # At least 1: a total batch size smaller than one pass over all
        # processes would otherwise give 0 and fail with ZeroDivisionError.
        gradient_accumulation_steps = max(
            1, args.total_train_batch_size // train_batch_size)
        num_train_epochs = args.num_train_epochs

    if rank < 0:
        # single process takes all samples
        q_sampler = RandomSampler(q_dataset)
        q_dataloader = DataLoader(q_dataset,
                                  sampler=q_sampler,
                                  batch_size=train_batch_size,
                                  num_workers=4)
    else:
        # special sampler that divides samples between processes
        q_sampler = torch.utils.data.distributed.DistributedSampler(q_dataset,
                                                                    rank=rank)
        q_dataloader = DataLoader(q_dataset,
                                  sampler=q_sampler,
                                  batch_size=per_gpu_train_batch_size)

    steps_total = int(
        len(q_dataloader) // gradient_accumulation_steps * num_train_epochs)

    # Prepare optimizer and schedule
    named_params, groups = utils.make_param_groups(
        rank,
        model,
        args.
        freeze_list,  # list or str with subnames to define frozen parameters
        args.learning_rate,  # learning rate for non-FQ parameters
        0.01,  # learning rate for FQ parameters
        fq_tune_only,  # true if only FQ parameters will be optimized
        model_controller)

    optimizer = AdamW(groups, eps=1e-08, lr=args.learning_rate, weight_decay=0)

    def lr_lambda(current_step):
        """Learning-rate factor: linear decay from 1 to 0 over steps_total."""
        # max(steps_total, 1): LambdaLR calls this while it is being
        # constructed, so steps_total == 0 must not divide by zero.
        p = float(current_step) / float(max(steps_total, 1))
        # never go negative if more steps than steps_total are made
        return max(1 - p, 0)

    scheduler = LambdaLR(optimizer, lr_lambda)

    if rank in [-1, 0]:
        for n, p in named_params:
            printlog('param for tune', n)
        printlog("fq_tune_only", fq_tune_only)
        printlog("dataset size", len(q_dataset))
        printlog("epoches", num_train_epochs)
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size)
        printlog("n_gpu", args.n_gpu)
        printlog("world_size", world_size)
        printlog("gradient_accumulation_steps", gradient_accumulation_steps)
        printlog("total train batch size",
                 train_batch_size * gradient_accumulation_steps)
        printlog("steps_total", steps_total)

    global_step = 1
    model.zero_grad()

    softplus = torch.nn.Softplus()

    # "key:value,key:value" -> dict; empty items (empty string, trailing
    # comma) are skipped
    loss_cfg = dict(t.split(':') for t in args.loss_cfg.split(',') if t)

    triplet_num = int(loss_cfg.get('triplet_num', 1))
    if fq_tune_only:
        triplet_num = 0
    if triplet_num > 2:
        # only triplet variants 0 (question vs negative) and 1 (positive vs
        # negative) are implemented below
        raise ValueError(
            'triplet_num={} is unsupported'.format(triplet_num))

    # history of hard-negative-mining candidates:
    # global_step -> (context ids, context embeddings)
    hnm_hist = {}

    # set by the training loop, reported in the log once available
    score = None
    time_last = None

    for epoch in range(math.ceil(num_train_epochs)):
        indicators = collections.defaultdict(list)
        model.train()
        if model_t:
            # NOTE(review): the teacher is switched to train mode as well
            # (its forward passes run under torch.no_grad()) - confirm this
            # is intended.
            model_t.train()
        if rank > -1:
            # set epoch to get a different division of samples between
            # processes for different epochs
            q_sampler.set_epoch(epoch)

        utils.sync_models(rank, model)
        for step, q_batch in enumerate(q_dataloader):
            # fractional epoch; allows num_train_epochs to be non-integer
            epoch_fp = epoch + step / len(q_dataloader)
            if epoch_fp > num_train_epochs:
                break

            losses = []

            context_ids_pos = q_batch[3]
            q_inputs = get_inputs(q_batch, args.device)
            q_outputs = model(**q_inputs,
                              output_hidden_states=(model_t is not None))
            q_vec = q_outputs[0]

            # get positive embeddings
            c_batch = train_dataset_qc.c_dataset[context_ids_pos.detach().data]
            c_inputs = get_inputs(c_batch, args.device)
            c_outputs = model(**c_inputs,
                              output_hidden_states=(model_t is not None))
            c_vec_pos = c_outputs[0]

            if model_t is not None:
                q_emb_s, q_hidden_s = q_outputs
                c_emb_s, c_hidden_s = c_outputs
                with torch.no_grad():
                    q_emb_t, q_hidden_t = model_t(**q_inputs,
                                                  output_hidden_states=True)
                    c_emb_t, c_hidden_t = model_t(**c_inputs,
                                                  output_hidden_states=True)

                def align_and_loss_outputs(out_s, out_t):
                    """Pair student/teacher outputs and return their MSEs."""
                    if len(out_s) != len(out_t):
                        # the student and teacher outputs are not aligned:
                        # pick a teacher output for each student output by
                        # spreading the indices evenly over the teacher's
                        n_s, n_t = len(out_s), len(out_t)
                        out_t = [
                            out_t[(i * (n_t - 1)) // (n_s - 1)]
                            for i in range(n_s)
                        ]
                    assert len(out_s) == len(
                        out_t
                    ), "can not align number of outputs between student and teacher"
                    assert all(
                        s[0] == s[1]
                        for s in zip(out_s[0].shape, out_t[0].shape)
                    ), "output shapes for student and teacher are not the same"
                    return [(s - t.detach()).pow(2).mean()
                            for s, t in zip(out_s, out_t)]

                l_q = align_and_loss_outputs(q_hidden_s, q_hidden_t)
                l_c = align_and_loss_outputs(c_hidden_s, c_hidden_t)

                emb_loss = loss_cfg.get('emb_loss', '')
                if emb_loss == 'L2':
                    l_q.append((q_emb_s - q_emb_t.detach()).pow(2).mean())
                    l_c.append((c_emb_s - c_emb_t.detach()).pow(2).mean())
                elif emb_loss == 'L1':
                    l_q.append((q_emb_s - q_emb_t.detach()).abs().mean())
                    l_c.append((c_emb_s - c_emb_t.detach()).abs().mean())
                elif emb_loss.lower() not in ['', 'none', '0', 'disable']:
                    raise Exception(
                        'emb_loss={} is unsupported'.format(emb_loss))

                losses.extend([args.supervision_weight * l for l in l_c + l_q])

            if triplet_num > 0:
                # disable grad to select negatives
                with torch.no_grad():
                    hnm_scores = []
                    hnm_idxs = []

                    # check that the current step has no HNM context vectors
                    if global_step not in hnm_hist and args.hnm_num > 0:
                        # generate new ones

                        if world_size > 1 and (args.hnm_num % world_size) != 0:
                            # align hnm_num so that every replica gets the
                            # same number of contexts
                            hnm_plus = world_size - (args.hnm_num % world_size)
                            args.hnm_num += hnm_plus
                            logger.warning(
                                "rank {} args.hnm_num increased by {} from {} to {} to be the same after division by {} replicas."
                                .format(rank, hnm_plus,
                                        args.hnm_num - hnm_plus, args.hnm_num,
                                        world_size))

                        # generate random contexts to calc embedding
                        context_ids_all = torch.randint(
                            low=0,
                            high=len(train_dataset_qc.c_dataset),
                            size=[args.hnm_num])

                        if rank < 0:  # single process takes all
                            context_ids = context_ids_all
                        else:
                            # broadcast rank 0's indices to all processes
                            context_ids_all = context_ids_all.to(args.device)
                            torch.distributed.broadcast(context_ids_all, 0)
                            context_ids_all = context_ids_all.cpu()

                            # each process takes only its own part to
                            # calc embeddings
                            s = ((rank + 0) * args.hnm_num) // world_size
                            e = ((rank + 1) * args.hnm_num) // world_size
                            context_ids = context_ids_all[s:e]

                        batch_size = min(args.hnm_batch_size,
                                         context_ids.shape[0])

                        # embed the contexts chunk by chunk
                        s, e = 0, batch_size
                        c_outputs = []
                        while e > s:
                            idx = context_ids.detach()[s:e]
                            c_batch = train_dataset_qc.c_dataset[idx]
                            inputs = get_inputs(c_batch, args.device)
                            outputs = model(**inputs,
                                            output_hidden_states=False)
                            c_outputs.append(outputs[0])
                            s, e = e, min(e + batch_size, context_ids.shape[0])

                        context_emb = torch.cat(c_outputs, dim=0)

                        if rank < 0:
                            # single process calculated all
                            context_emb_all = context_emb
                        else:
                            # collect the parts calculated by all processes
                            context_emb_list = [
                                torch.zeros_like(context_emb)
                                for _ in range(world_size)
                            ]
                            torch.distributed.all_gather(
                                context_emb_list, context_emb)
                            context_emb_all = torch.cat(context_emb_list,
                                                        dim=0)

                        hnm_hist[global_step] = (context_ids_all,
                                                 context_emb_all)

                        # check history size and crop the oldest one
                        if len(hnm_hist) > args.hnm_hist_num:
                            del hnm_hist[min(hnm_hist.keys())]

                    # calc HNM scores for the current question batch
                    for hist_step, (c_idx, c_vec) in hnm_hist.items():
                        # weight depends on the age of the history entry
                        w = args.hnm_hist_alpha**(global_step - hist_step)
                        t1 = q_vec[:, None, :]
                        t2 = c_vec[None, :, :]
                        d = (t1 - t2)
                        # score is the negated L2 distance
                        score = -d.norm(2, dim=-1)
                        score = score * w

                        hnm_scores.append(score)
                        hnm_idxs.append(c_idx)

                    if hnm_scores:
                        # choose the hardest negative if we have scores
                        score = torch.cat(hnm_scores, dim=-1)
                        idx = torch.cat(hnm_idxs, dim=-1)
                        score = score.cpu()
                        pos_mask = (context_ids_pos[:,
                                                    None] == idx[None, :]).to(
                                                        dtype=score.dtype,
                                                        device=score.device)
                        # give the positive context the minimal score to
                        # avoid choosing it as the hard negative
                        score = (1 - pos_mask) * score + pos_mask * score.min()
                        hn_idx = score.argmax(dim=1, keepdim=True)

                        context_ids_neg = idx[hn_idx]
                    else:
                        # just random selection in case of no scores for HNM:
                        # draw from one id less and shift ids at or above the
                        # positive one by 1, so the positive is never drawn
                        size = (context_ids_pos.shape[0], 1)
                        context_ids_neg = torch.randint(
                            0,
                            len(train_dataset_qc.c_dataset) - 1, size)
                        shift = (context_ids_neg >= context_ids_pos[:, None])
                        context_ids_neg = context_ids_neg + shift.to(
                            dtype=context_ids_neg.dtype)

                d_pos = (q_vec - c_vec_pos).norm(2, dim=-1)
                # get negative embeddings and calc losses
                for neg_index in range(context_ids_neg.shape[1]):
                    ids = context_ids_neg[:, neg_index]
                    c_batch = train_dataset_qc.c_dataset[ids.detach()]
                    inputs = get_inputs(c_batch, args.device)

                    outputs = model(**inputs, output_hidden_states=False)
                    c_vec_neg = outputs[0]

                    for triplet_index in range(triplet_num):

                        if triplet_index == 0:
                            # question vs negative context
                            d_neg = (q_vec - c_vec_neg).norm(2, dim=-1)
                        if triplet_index == 1:
                            # positive context vs negative context
                            d_neg = (c_vec_pos - c_vec_neg).norm(2, dim=-1)

                        d_diff = d_pos - d_neg

                        indicators['dd' + str(triplet_index)].append(
                            [v.mean().item() for v in (d_pos, d_neg, d_diff)])

                        l = softplus(d_diff)
                        losses.append(l)

                        del d_neg
                del d_pos

                # average over batch
                losses = [l.mean() for l in losses]

            l = sum(losses) / len(losses)
            (l / gradient_accumulation_steps).backward()

            indicators['loss'].append(l.item())
            indicators['ll'].append([lll.item() for lll in losses])

            del l

            if (step + 1) % gradient_accumulation_steps == 0:

                utils.sync_grads(rank,
                                 named_params,
                                 report_no_grad_params=(global_step == 1))
                torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if global_step % 10 == 0:
                    # Log metrics
                    lrp = [
                        '{:.2f}'.format(i)
                        for i in np.log10(scheduler.get_last_lr())
                    ]

                    str_out = "{} ep {:.2f} lrp {}".format(
                        train_count, epoch_fp, " ".join(lrp))

                    for k, v in indicators.items():
                        v = np.array(v)
                        if len(v.shape) == 1:
                            v = v[:, None]

                        if rank > -1:
                            # sync indicators
                            vt = torch.tensor(v).to(args.device)
                            torch.distributed.all_reduce(
                                vt, op=torch.distributed.ReduceOp.SUM)
                            v = vt.cpu().numpy() / float(world_size)

                        str_out += " {} {}".format(
                            k,
                            " ".join(["{:.3f}".format(t) for t in v.mean(0)]))

                    if score is not None:
                        str_out += " SS {}".format(list(score.shape))

                    if time_last is not None:
                        dt_iter = (time.time() - time_last) / len(
                            indicators['loss'])
                        dt_ep = dt_iter * len(q_dataloader)
                        str_out += " it {:.1f}s".format(dt_iter)
                        str_out += " ep {:.1f}m".format(dt_ep / (60))
                        str_out += " eta {:.1f}h".format(
                            dt_ep * (num_train_epochs - epoch_fp) / (60 * 60))
                    time_last = time.time()

                    indicators = collections.defaultdict(list)
                    if rank in [-1, 0]:
                        logger.info(str_out)

        if rank in [-1, 0]:
            check_point_name = 'checkpoint-{:02}'.format(train_count)
            check_point_name = check_point_name + '-{:02}'.format(epoch + 1)
            result_s = evaluate(args, model.eval(), test_dataset_qc)
            for k, v in result_s.items():
                logger.info("{} {} {}".format(check_point_name, k, v))
        if rank > -1:
            torch.distributed.barrier()
Ejemplo n.º 4
0
    def _update_network(self, transitions):
        """Perform one actor and one critic optimisation step on a batch.

        ``transitions`` is modified in place: its 'obs', 'g' and 'obs_next'
        entries are replaced by the output of ``self._preproc_og`` and a
        'g_next' entry is added.

        Args:
            transitions: mapping that provides the 'obs', 'obs_next', 'g',
                'actions' and 'r' batches.

        Returns:
            dict with float entries 'actor_loss' and 'critic_loss', plus
            'cloning_loss' when ``self.bc_loss`` is enabled.
        """
        # pre-process the observation and goal; the same (unprocessed) goal
        # is used for both the current and the next observation
        o, o_next, g = (transitions['obs'], transitions['obs_next'],
                        transitions['g'])
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(
            o_next, g)

        # normalise and build the network inputs: [observation, goal]
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)

        # transfer them into tensors
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                               dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'],
                                      dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)

        if self.config['cuda']:
            inputs_norm_tensor = inputs_norm_tensor.cuda()
            inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
            actions_tensor = actions_tensor.cuda()
            r_tensor = r_tensor.cuda()

        # calculate the target Q value; everything inside no_grad is already
        # detached from the graph, so no explicit .detach() is needed
        with torch.no_grad():
            actions_next = self.actor_target_network(inputs_next_norm_tensor)
            q_next_value = self.critic_target_network(inputs_next_norm_tensor,
                                                      actions_next)
            target_q_value = r_tensor + self.config['gamma'] * q_next_value
            # clip the target into [-1 / (1 - gamma), 0]
            clip_return = 1 / (1 - self.config['gamma'])
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)

        # the q loss
        real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
        critic_loss = (target_q_value - real_q_value).pow(2).mean()

        # self.main.Q_tf ==> real_q_value
        # self.main.Q_pi_tf ==> self.critic_network(inputs_norm_tensor, actions_real) ==> approx_q_value

        # the actor loss
        action_l2 = self.config['action_l2']
        actions_real = self.actor_network(inputs_norm_tensor)
        approx_q_value = self.critic_network(inputs_norm_tensor, actions_real)

        if self.bc_loss:
            # train with demonstrations using behavior cloning

            # select the last `demo_batch_size` rows of the batch
            # NOTE(review): assumes the caller places the demo-buffer samples
            # at the end of the batch -- confirm against the sampling code.
            b_size = self.config['batch_size']
            demo_b_size = self.config['demo_batch_size']
            mask = np.concatenate(
                (np.zeros(b_size - demo_b_size, dtype=bool),
                 np.ones(demo_b_size, dtype=bool)), axis=0)
            # a bool mask is required: indexing with uint8 is deprecated
            mask = torch.tensor(mask,
                                dtype=torch.bool,
                                device=actions_real.device)

            if self.q_filter:
                # use Q-filter trick to perform BC only when needed: keep a
                # sample only if the stored action scores higher than the
                # actor's current action
                with torch.no_grad():
                    mask &= (real_q_value > approx_q_value).squeeze()

            prm_loss_weight = self.config['prm_loss_weight']
            cloning_loss = self.config['aux_loss_weight'] * (
                actions_real[mask] - actions_tensor[mask]).pow(2).sum()
        else:
            # train without demonstrations
            prm_loss_weight = 1.0
            cloning_loss = None

        actor_loss = -prm_loss_weight * approx_q_value.mean()
        actor_loss += prm_loss_weight * action_l2 * (
            actions_real / self.env_params['action_max']).pow(2).mean()

        if cloning_loss is not None:
            actor_loss += cloning_loss

        # update actor network
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        self.actor_optim.step()

        # update critic network; zero_grad also clears the critic gradients
        # produced by the actor loss backward pass above
        self.critic_optim.zero_grad()
        critic_loss.backward()
        sync_grads(self.critic_network)
        self.critic_optim.step()

        res = dict(actor_loss=actor_loss.item(),
                   critic_loss=critic_loss.item())
        if cloning_loss is not None:
            res['cloning_loss'] = cloning_loss.item()
        return res
    def train(self, epoch_start):
        """Run the training loop starting at epoch ``epoch_start``.

        Iterates epochs up to ``ceil(self.args.num_train_epochs)``. Inside an
        epoch, gradients are accumulated over
        ``self.gradient_accumulation_steps`` batches before each optimizer
        and scheduler step. Every ``self.args.logacc`` optimizer steps the
        indicators are averaged, checked for a sudden loss raise and logged.
        A checkpoint is saved and evaluated at the end of every epoch.
        """
        self.check_loss_raise = CheckLossRaise()
        global_step = 0

        for epoch in range(epoch_start, math.ceil(self.args.num_train_epochs)):
            self.indicators = collections.defaultdict(list)

            utils.sync_models(self.rank, self.model)

            self.model.train()

            # start the epoch with clean gradients and an empty accumulator
            self.model.zero_grad()
            accumulated = 0

            if self.rank > -1:
                # set epoch so that the samples are divided differently
                # between processes in different epochs
                self.dataloader.sampler.set_epoch(epoch)

            for step, batch in enumerate(self.dataloader):
                # fractional epoch position; stop once the (possibly
                # fractional) epoch budget is exceeded
                epoch_fp = epoch + step/len(self.dataloader)
                if epoch_fp > self.args.num_train_epochs:
                    break

                noise_wave, target_wave = [
                    item.to(self.args.device) for item in batch
                ]

                # augment and mix signals
                target_wave, noise_wave, mixed_wave = self.mix_signals(
                    target_wave, noise_wave)

                # forward pass
                out_wave, out_spec, _ = self.model(mixed_wave)

                # spectrum of the clean input signal (left-padded)
                left_pad = self.model.wnd_length - self.model.hop_length
                target_spec = self.model.encode(
                    torch.nn.functional.pad(target_wave, (left_pad, 0)))

                # crop target and model output to align to each other
                sample_ahead = self.model.get_sample_ahead()
                spectre_ahead = self.model.ahead
                if sample_ahead > 0:
                    # the cropped mixture is not used again in this method
                    mixed_wave = mixed_wave[:, :-sample_ahead]
                    target_wave = target_wave[:, :-sample_ahead]
                    out_wave = out_wave[:, sample_ahead:]
                if spectre_ahead > 0:
                    out_spec = out_spec[:, :, :, spectre_ahead:]
                    target_spec = target_spec[:, :, :, :-spectre_ahead]

                loss = self.loss(epoch_fp, out_wave, out_spec, target_wave,
                                 target_spec)
                self.indicators['loss'].append(loss.item())

                # calculate and accumulate gradients
                loss.backward()
                accumulated += 1

                # only step once enough gradients have been accumulated
                if accumulated >= self.gradient_accumulation_steps:
                    utils.sync_grads(self.rank, self.named_params,
                                     global_step == 0, accumulated)
                    self.optimizer.step()  # make optimization step
                    self.scheduler.step()  # update learning rate schedule
                    global_step += 1

                    self.model.zero_grad()
                    accumulated = 0

                    # make logs only after several steps
                    if global_step % self.args.logacc == 0:
                        # average indicators over GPUs and iterations
                        self.aver_indicators()

                        # check that negsisdr does not suddenly raise;
                        # if a high raise is detected the model parameters
                        # are restored and the optimizer is reset
                        self.check_loss_raise.check(
                            self.indicators_mean["negsisdr"],
                            self.named_params,
                            self.optimizer
                        )

                        self.log_indicators(epoch_fp)

                        self.indicators = collections.defaultdict(list)

            self.save_and_eval_checkpoint(epoch+1)
Ejemplo n.º 6
0
def train(rank, args, model, model_t, train_dataset_qa, test_dataset_qa, scale_tune):
    """Train ``model`` and evaluate it after every epoch.

    Increments the module-level ``train_count``. When ``model_t`` is given it
    is used as a teacher: an optional knowledge-distillation term
    (``args.kd_weight``) and per-layer hidden-state matching terms
    (``args.supervision_weight``) are added to the model's own loss.

    Args:
        rank: process rank; a negative value selects single-process mode,
            otherwise ``torch.distributed`` is used. Ranks -1 and 0 do the
            logging and the per-epoch evaluation.
        args: run configuration. Attributes read here:
            per_gpu_train_batch_size, total_train_batch_size,
            num_train_epochs, freeze_list, learning_rate, n_gpu, device,
            kd_weight, supervision_weight.
        model: the model to tune.
        model_t: optional teacher model, or None.
        train_dataset_qa: dataset used for training.
        test_dataset_qa: dataset passed to ``evaluate`` after each epoch.
        scale_tune: when truthy only parameters with '.scale' in their name
            are tuned, for a single epoch without gradient accumulation.
    """
    global train_count
    train_count += 1
    is_main = rank in [-1, 0]
    world_size = 1 if rank < 0 else torch.distributed.get_world_size()

    if is_main:
        printlog("Train model",train_count)
        printlog(model)

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size
    # clamp to 1: a requested total batch smaller than one synchronous batch
    # would otherwise give 0 and break the divisions/modulo below
    gradient_accumulation_steps = max(1, args.total_train_batch_size // train_batch_size)
    num_train_epochs = args.num_train_epochs

    if scale_tune:
        gradient_accumulation_steps = 1
        num_train_epochs = 1

    if rank < 0:
        #single process take all samples
        sampler = RandomSampler(train_dataset_qa)
        dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=train_batch_size, num_workers=4)
    else:
        #special sampler that divide samples between processes
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset_qa, rank=rank)
        dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=per_gpu_train_batch_size)

    steps_total = int(len(dataloader) // gradient_accumulation_steps * num_train_epochs)

    # Prepare optimizer and schedule
    # Empty entries (e.g. from a trailing comma) are dropped: an empty string
    # is a substring of every name and would freeze all parameters.
    freeze_list = []
    if args.freeze_list:
        freeze_list = [fn.strip() for fn in args.freeze_list.split(',')]
        freeze_list = [fn for fn in freeze_list if fn]
    named_params = []
    for n, p in model.named_parameters():
        # NOTE(review): the "none" test is applied to the parameter name, not
        # to the freeze-list entry; kept as is -- confirm the intent.
        if n.lower()!="none" and any(fn in n for fn in freeze_list):
            if is_main:
                logger.warning("rank {} {} param is frozen and excluded from tune".format(rank,n))
            continue
        named_params.append( (n, p) )

    # split parameters to scale and the rest
    named_params_scale = [(n, p) for n, p in named_params if '.scale' in n]
    named_params_rest = [(n, p) for n, p in named_params if '.scale' not in n]

    if scale_tune:
        #keep only scale parameters
        named_params = named_params_scale
        named_params_rest = []

    # scale parameters get a fixed learning rate, the rest use the configured one
    groups = []
    if named_params_scale:
        groups.append({'params': [p for n, p in named_params_scale], 'lr': 0.01})
    if named_params_rest:
        groups.append({'params': [p for n, p in named_params_rest],  'lr': args.learning_rate})

    optimizer = AdamW(
        groups,
        eps=1e-08,
        lr=args.learning_rate,
        weight_decay=0)

    def lr_lambda(current_step):
        # linear decay from 1 to 0 over steps_total optimizer steps; guarded
        # against steps_total == 0 and clamped so that steps beyond
        # steps_total never produce a negative learning rate
        p = float(current_step) / float(max(1, steps_total))
        return max(0.0, 1.0 - p)

    scheduler = LambdaLR(optimizer, lr_lambda)

    def align_and_loss_outputs(out_s, out_t):
        # per-layer mean squared error between student and teacher outputs
        if len(out_s) != len(out_t):
            #the student and teacher outputs are not aligned. try to find teacher output for each student output
            n_s, n_t = len(out_s), len(out_t)
            # max(1, ...) avoids a zero division for a single student output
            out_t = [out_t[(i*(n_t-1))//max(1, n_s-1)] for i in range(n_s)]
        assert len(out_s) == len(out_t), "can not align number of outputs between student and teacher"
        assert all(s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)), "output shapes for student and teacher are not the same"
        return [(s - t.detach()).pow(2).mean() for s,t in zip(out_s, out_t)]

    if is_main:
        for n,p in named_params:
            printlog('param for tune',n)
        printlog("scale_tune", scale_tune )
        printlog("dataset size", len(train_dataset_qa) )
        printlog("epoches", num_train_epochs )
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size )
        printlog("n_gpu", args.n_gpu )
        printlog("world_size", world_size )
        printlog("gradient_accumulation_steps", gradient_accumulation_steps )
        printlog("total train batch size", train_batch_size * gradient_accumulation_steps )
        printlog("steps_total",steps_total )

    global_step = 0
    model.zero_grad()
    use_teacher = model_t is not None
    kd_temperature = 1
    # timestamp of the previous log line; None until the first one is written
    time_last = None

    for epoch in range(math.ceil(num_train_epochs)):
        indicators = collections.defaultdict(list)
        model.train()
        set_output_hidden_states(rank, model, use_teacher)
        utils.sync_models(rank, model)
        if use_teacher:
            set_output_hidden_states(rank, model_t, True)
            model_t.train()
        if rank > -1:
            #set epoch to make different samples division between processes for different epochs
            sampler.set_epoch(epoch)

        for step, batch in enumerate(dataloader):
            # fractional epoch position; lets a fractional epoch budget stop mid-epoch
            epoch_fp = epoch + step/len(dataloader)
            if epoch_fp > num_train_epochs:
                break

            losses = []

            inputs = get_inputs(batch, args.device)
            targets = get_targets(batch, args.device)
            outputs = model(**inputs, **targets, output_hidden_states=use_teacher)
            # first output is the model's own loss, the rest are its predictions
            losses.append(outputs[0])
            outputs = outputs[1:]

            if use_teacher:
                with torch.no_grad():
                    outputs_t = model_t(**inputs, output_hidden_states=True)
                    hidden_t = outputs_t[2]
                    assert isinstance(hidden_t, (tuple,list)), "hidden states output is not detected right"
                    assert len(hidden_t) == model_t.config.num_hidden_layers+1, "hidden states output is not detected right"

                if args.kd_weight>0:
                    # Calculate knowledge distillation loss on the first two outputs
                    kd_losses = []
                    for logit_s,logit_t in zip(outputs[0:2],outputs_t[0:2]):
                        T = kd_temperature
                        prob_t = torch.nn.functional.softmax(logit_t.detach() / T, dim=1)
                        logprob_s = torch.nn.functional.log_softmax(logit_s / T, dim=1)
                        kd_losses.append( -(logprob_s * prob_t).mean() * (T * T * prob_t.shape[1]) )
                    losses.append(args.kd_weight*sum(kd_losses)/len(kd_losses))

                hidden_s = outputs[2]
                assert isinstance(hidden_s, (tuple,list)), "hidden states output is not detected right"
                assert len(hidden_s) == model.config.num_hidden_layers+1, "hidden states output is not detected right"

                sw_losses = align_and_loss_outputs(hidden_s,hidden_t)

                losses.extend([args.supervision_weight*sw for sw in sw_losses])

            #average over batch
            losses = [term.mean() for term in losses]

            loss = sum(losses)/len(losses)
            indicators['loss'].append(loss.item())
            indicators['ll'].append([term.item() for term in losses])

            # scale so that the accumulated gradient is an average over the accumulation window
            (loss/gradient_accumulation_steps).backward()

            del loss

            if (step + 1) % gradient_accumulation_steps == 0:
                global_step += 1

                utils.sync_grads(rank, named_params, report_no_grad_params=(global_step==1))
                torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

                if global_step % 50 == 0:
                    # Log metrics
                    lrp = " ".join(['{:.2f}'.format(t) for t in np.log10(scheduler.get_last_lr())])
                    str_out = "{} ep {:.2f} lrp {}".format(train_count, epoch_fp, lrp)

                    for k,v in indicators.items():
                        v = np.array(v)
                        if v.ndim == 1:
                            v = v[:,None]

                        if rank>-1:
                            #sync indicators: average them over all processes
                            vt = torch.tensor(v).to(args.device)
                            torch.distributed.all_reduce(vt, op=torch.distributed.ReduceOp.SUM)
                            v = vt.cpu().numpy() / float(world_size)

                        str_out += " {} {}".format(k," ".join(["{:.3f}".format(t) for t in v.mean(0)]))

                    if time_last is not None:
                        #estimate processing times
                        dt_iter = (time.time() - time_last) / len(indicators['loss'])
                        dt_ep = dt_iter * len(dataloader)
                        str_out += " it {:.1f}s".format(dt_iter)
                        str_out += " ep {:.1f}m".format(dt_ep / (60))
                        str_out += " eta {:.1f}h".format(dt_ep * (num_train_epochs - epoch_fp) / (60 * 60))
                    time_last = time.time()

                    indicators = collections.defaultdict(list)
                    if is_main:
                        logger.info(str_out)

        if is_main:
            check_point_name = 'checkpoint-{:02}-{:02}'.format(train_count, epoch + 1)
            model.eval()
            set_output_hidden_states(rank, model, False)
            result_s = evaluate(args, model, test_dataset_qa)
            for k,v in result_s.items():
                logger.info("{} {} {}".format(check_point_name, k, v))
        if rank>-1:
            # keep the other processes waiting while the main one evaluates
            torch.distributed.barrier()