Exemple #1
0
def train(epoch):
    """Run one pass over the training generator, updating ``model`` in place.

    For every batch this builds the source/target tensors, runs the model,
    adds the SGM loss to five times the separation loss (plus an optional
    discriminator loss), back-propagates and steps the optimizer(s).
    Progress is reported through ``print``, ``logging`` and ``lera.log``; on
    some updates the separated outputs are written out and scored with
    ``bss_test.cal``.

    The function depends on many module-level names (``model``, ``config``,
    ``opt``, ``optim``, ``scheduler``, ``lera``, ``dict_spk2idx``,
    ``mix_speech_len``, ``use_cuda``, ...) and mutates the globals declared
    in the ``global`` statement below.

    Args:
        epoch: Epoch number. Used in log lines, stored on
            ``model.current_epoch`` when ``opt.model == 'gated'``, and used
            to gate the periodic evaluation branch (``epoch > 1``).

    Returns:
        None.
    """
    # NOTE(review): `e` is assigned here but only declared global further
    # down; Python 2 warns about this ordering and Python 3 rejects it.
    e = epoch
    model.train()
    # Per-batch SDR values collected during this epoch.
    SDR_SUM = np.array([])

    if config.schedule:
        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])
        if config.is_dis:
            scheduler_dis.step()
        lera.log({
            'lr': scheduler.get_lr()[0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    global e, updates, total_loss, start_time, report_total, total_loss_sgm, total_loss_ss
    if config.MLMSE:
        global Var

    train_data_gen = prepare_data('once', 'train')
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in trainloader:
    while True:
        # NOTE(review): the whole iteration, including the generator call, is
        # inside this try; both handlers `continue`, so any exception raised
        # by `.next()` (StopIteration included) would be retried forever.
        try:
            train_data = train_data_gen.next()
            # The generator signals exhaustion by yielding False.
            if train_data == False:
                print 'SDR_aver_epoch:', SDR_SUM.mean()
                break  # if this epoch's generator has no more data, move straight on to the next epoch

            src = Variable(torch.from_numpy(train_data['mix_feas']))
            # raw_tgt = [spk.keys() for spk in train_data['multi_spk_fea_list']]
            # Speaker names per sample, sorted so the order is deterministic.
            raw_tgt = [
                sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']
            ]
            feas_tgt = models.rank_feas(
                raw_tgt,
                train_data['multi_spk_fea_list'])  # target spectrograms here: aim_size, len, fre

            # Make sure the tensors below are all LongTensor (long integers).
            tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
            # Each row is [0] + speaker indices, padded with <EOS> up to
            # tgt_max_len entries, then transposed so dim 0 is the sequence.
            tgt = Variable(
                torch.from_numpy(
                    np.array([[0] + [dict_spk2idx[spk] for spk in spks] +
                              (tgt_max_len - len(spks) - 1) *
                              [dict_spk2idx['<EOS>']] for spks in raw_tgt],
                             dtype=np.int))).transpose(0,
                                                       1)  # convert to indices, then add the start and end symbols at the front and back.
            # Every sample gets the same source length, mix_speech_len.
            src_len = Variable(
                torch.LongTensor(config.batch_size).zero_() +
                mix_speech_len).unsqueeze(0)
            # Number of speakers in each sample of the batch.
            tgt_len = Variable(
                torch.LongTensor([
                    len(one_spk)
                    for one_spk in train_data['multi_spk_fea_list']
                ])).unsqueeze(0)
            if use_cuda:
                # NOTE(review): src is only transposed on the CUDA path, yet
                # it is transposed back unconditionally further down.
                src = src.cuda().transpose(0, 1)
                tgt = tgt.cuda()
                src_len = src_len.cuda()
                tgt_len = tgt_len.cuda()
                feas_tgt = feas_tgt.cuda()

            model.zero_grad()
            # optim.optimizer.zero_grad()

            # aim_list holds the flat indices of the positions that contain a real speaker
            # (i.e. target entries, excluding first and last rows, that are not <EOS>).
            aim_list = (tgt[1:-1].transpose(0, 1).contiguous().view(-1) !=
                        dict_spk2idx['<EOS>']).nonzero().squeeze()
            aim_list = aim_list.data.cpu().numpy()

            outputs, targets, multi_mask = model(
                src, src_len, tgt, tgt_len,
                dict_spk2idx)  # outputs here are the hidden_outputs, the hidden layer before the final classification; usable directly
            print 'mask size:', multi_mask.size()

            # With more than one GPU the model's own methods are reached
            # through `.module`.
            if 1 and len(opt.gpus) > 1:
                sgm_loss, num_total, num_correct = model.module.compute_loss(
                    outputs, targets, opt.memory)
            else:
                sgm_loss, num_total, num_correct = model.compute_loss(
                    outputs, targets, opt.memory)
            print 'loss for SGM,this batch:', sgm_loss.data[0] / num_total

            src = src.transpose(0, 1)
            # expand the raw mixed-features to topk_max channel.
            siz = src.size()
            assert len(siz) == 3
            topk_max = config.MAX_MIX  # maximum possible number of topk
            # Repeat each mixture topk_max times, flatten the first two dims,
            # then keep only the rows that correspond to real speakers.
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1],
                siz[2]).contiguous().view(-1, siz[1], siz[2])
            x_input_map_multi = x_input_map_multi[aim_list]
            multi_mask = multi_mask.transpose(0, 1)

            if 1 and len(opt.gpus) > 1:
                if config.MLMSE:
                    Var = model.module.update_var(x_input_map_multi,
                                                  multi_mask, feas_tgt)
                    lera.log_image(u'Var weight',
                                   Var.data.cpu().numpy().reshape(
                                       config.speech_fre, config.speech_fre,
                                       1).repeat(3, 2),
                                   clip=(-1, 1))
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi, multi_mask, feas_tgt, Var)
                else:
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi, multi_mask, feas_tgt)
            else:
                # NOTE(review): config.MLMSE is not consulted on this path.
                ss_loss = model.separation_loss(x_input_map_multi, multi_mask,
                                                feas_tgt)

            # Separation loss is weighted five times the SGM loss.
            loss = sgm_loss + 5 * ss_loss
            # dis_loss model
            if config.is_dis:
                dis_loss = models.loss.dis_loss(config, topk_max, model_dis,
                                                x_input_map_multi, multi_mask,
                                                feas_tgt, func_dis)
                loss = loss + dis_loss
                # print 'dis_para',model_dis.parameters().next()[0]
                # print 'ss_para',model.parameters().next()[0]

            loss.backward()
            # print 'totallllllllllll loss:',loss
            total_loss_sgm += sgm_loss.data[0]
            total_loss_ss += ss_loss.data[0]
            lera.log({
                'sgm_loss': sgm_loss.data[0],
                'ss_loss': ss_loss.data[0],
                'loss:': loss.data[0],
            })

            # Score the separation whenever `updates % eval_interval` falls
            # in 0..10.
            if (updates % config.eval_interval) in [
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
            ]:
                predicted_maps = multi_mask * x_input_map_multi
                # predicted_maps=Variable(feas_tgt)
                utils.bss_eval(config,
                               predicted_maps,
                               train_data['multi_spk_fea_list'],
                               raw_tgt,
                               train_data,
                               dst='batch_outputjaa')
                # Drop the local references to the large tensors.
                del predicted_maps, multi_mask, x_input_map_multi
                # raw_input('wait to continue......')
                sdr_aver_batch = bss_test.cal('batch_outputjaa/')
                lera.log({'SDR sample': sdr_aver_batch})
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                print 'SDR_aver_now:', SDR_SUM.mean()

            total_loss += loss.data[0]
            report_total += num_total
            optim.step()
            if config.is_dis:
                optim_dis.step()

            updates += 1
            # Every 30 updates: log the running SGM / separation losses
            # (averaged over 30) and reset their accumulators.
            if updates % 30 == 0:
                logging(
                    "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f\n"
                    % (time.time() - start_time, epoch, updates, loss /
                       num_total, total_loss_sgm / 30.0, total_loss_ss / 30.0))
                total_loss_sgm, total_loss_ss = 0, 0

            # continue

            if 0 or updates % config.eval_interval == 0 and epoch > 1:
                logging(
                    "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n"
                    % (time.time() - start_time, epoch, updates,
                       total_loss / report_total))
                print('evaluating after %d updates...\r' % updates)
                # NOTE(review): the evaluation call below is commented out, so
                # `score` is never assigned in this function. Unless a
                # module-level `score` exists, the loop below raises NameError,
                # which the generic handler at the bottom catches and skips.
                # score = eval(epoch)
                for metric in config.metric:
                    scores[metric].append(score[metric])
                    lera.log({
                        'sgm_micro_f1': score[metric],
                    })
                    # Save when the new score ties or beats the history
                    # (the new score has already been appended).
                    if metric == 'micro_f1' and score[metric] >= max(
                            scores[metric]):
                        save_model(log_path + 'best_' + metric +
                                   '_checkpoint.pt')
                    if metric == 'hamming_loss' and score[metric] <= min(
                            scores[metric]):
                        save_model(log_path + 'best_' + metric +
                                   '_checkpoint.pt')

                model.train()
                total_loss = 0
                # NOTE(review): resetting to 0 makes later
                # `time.time() - start_time` values equal the raw timestamp;
                # possibly meant to be time.time() - confirm.
                start_time = 0
                report_total = 0

        # Errors are printed and the batch is skipped; training continues.
        except RuntimeError, eeee:
            print 'Erros here eeee: ', eeee
            continue
        except Exception, dddd:
            print '\n\n\nRare errors: ', dddd
            continue
Exemple #2
0
def train(epoch):
    """Run one pass over the training generator, updating ``model`` in place.

    For every batch this builds the source/target tensors, optionally derives
    a ratio mask from the targets (``config.WFM``), runs the model, and sums
    the SGM loss, five times the separation loss, and the optional
    ``unit_norm``, reID and discriminator terms before back-propagating and
    stepping the optimizer(s). Losses are reported through ``print``,
    ``logging`` and ``lera.log``; evaluation and checkpointing run on
    intervals taken from ``config``.

    The function depends on many module-level names (``model``, ``config``,
    ``opt``, ``optim``, ``scheduler``, ``lera``, ``dict_spk2idx``,
    ``mix_speech_len``, ``speech_fre``, ``use_cuda``, ...) and mutates the
    globals declared in the ``global`` statement below.

    Args:
        epoch: Epoch number. Used in log lines, stored on
            ``model.current_epoch`` when ``opt.model == 'gated'``, passed to
            ``eval`` and used to gate evaluation (``epoch > 1``).

    Returns:
        None.
    """
    # NOTE(review): `e` is assigned here but only declared global further
    # down; Python 2 warns about this ordering and Python 3 rejects it.
    e = epoch
    model.train()

    if config.schedule:
        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])
        if config.is_dis:
            scheduler_dis.step()
        lera.log({
            'lr': scheduler.get_lr()[0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    global e, updates, total_loss, start_time, report_total, total_loss_sgm, total_loss_ss
    if config.MLMSE:
        global Var

    train_data_gen = prepare_data('once', 'train')
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in trainloader:
    while True:
        train_data = train_data_gen.next()
        # The generator signals exhaustion by yielding False.
        if train_data == False:
            break  # if this epoch's generator has no more data, move straight on to the next epoch

        src = Variable(torch.from_numpy(train_data['mix_feas']))
        # raw_tgt = [spk.keys() for spk in train_data['multi_spk_fea_list']]
        # Speaker names per sample, sorted so the order is deterministic.
        raw_tgt = [
            sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']
        ]
        feas_tgt = models.rank_feas(
            raw_tgt, train_data['multi_spk_fea_list'])  # target spectrograms here

        # Make sure the tensors below are all LongTensor (long integers).
        # Each row is [0] + speaker indices + [<EOS>], transposed so dim 0 is
        # the sequence.
        tgt = Variable(
            torch.from_numpy(
                np.array([[0] + [dict_spk2idx[spk]
                                 for spk in spks] + [dict_spk2idx['<EOS>']]
                          for spks in raw_tgt],
                         dtype=np.int))).transpose(0, 1)  # convert to indices, then add the start and end symbols at the front and back.
        # Every sample gets the same source length, mix_speech_len.
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        # The target length of the whole batch is taken from the first
        # sample's speaker count.
        tgt_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            len(train_data['multi_spk_fea_list'][0])).unsqueeze(0)

        if config.WFM:
            # Ratio mask: each target's squared features divided by the sum
            # of squares over dim 1 (epsilon avoids division by zero).
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 4
            feas_tgt_square = feas_tgt * feas_tgt
            feas_tgt_square_sum = torch.sum(feas_tgt_square,
                                            dim=1,
                                            keepdim=True).expand(tmp_size)
            WFM_mask = feas_tgt_square / (feas_tgt_square_sum + 1e-10)

        if use_cuda:
            # NOTE(review): src is only transposed on the CUDA path, yet it
            # is transposed back unconditionally further down.
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()
        # A RuntimeError anywhere in the forward/backward/step section is
        # printed and the batch is skipped (see the handler below).
        try:
            model.zero_grad()
            # optim.optimizer.zero_grad()
            outputs, targets, multi_mask = model(
                src, src_len, tgt,
                tgt_len)  # outputs here are the hidden_outputs, the hidden layer before the final classification; usable directly
            print 'mask size:', multi_mask.size()

            # With more than one GPU the model's own methods are reached
            # through `.module`.
            if 1 and len(opt.gpus) > 1:
                sgm_loss, num_total, num_correct = model.module.compute_loss(
                    outputs, targets, opt.memory)
            else:
                sgm_loss, num_total, num_correct = model.compute_loss(
                    outputs, targets, opt.memory)
            print 'loss for SGM,this batch:', sgm_loss.data[0] / num_total

            if config.unit_norm:  #outputs---[len+1,bs,2*d]
                assert not config.global_emb
                # Dot product between the first two output steps; only values
                # above the config.unit_norm threshold are kept and averaged.
                unit_dis = (outputs[0] * outputs[1]).sum(1)
                print 'unit_dis this batch:', unit_dis.data.cpu().numpy()
                unit_dis = torch.masked_select(unit_dis,
                                               unit_dis > config.unit_norm)
                if len(unit_dis) > 0:
                    unit_dis = unit_dis.mean()

            src = src.transpose(0, 1)
            # expand the raw mixed-features to topk channel.
            siz = src.size()
            assert len(siz) == 3
            topk = feas_tgt.size()[1]
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk, siz[1], siz[2])
            multi_mask = multi_mask.transpose(0, 1)

            if config.WFM:
                # Replace the targets with the mixture scaled by the ratio mask.
                feas_tgt = x_input_map_multi.data * WFM_mask
            if 1 and len(opt.gpus) > 1:
                if config.MLMSE:
                    Var = model.module.update_var(x_input_map_multi,
                                                  multi_mask, feas_tgt)
                    lera.log_image(u'Var weight',
                                   Var.data.cpu().numpy().reshape(
                                       config.speech_fre, config.speech_fre,
                                       1).repeat(3, 2),
                                   clip=(-1, 1))
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi, multi_mask, feas_tgt, Var)
                else:
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi, multi_mask, feas_tgt)
            else:
                # NOTE(review): config.MLMSE is not consulted on this path.
                ss_loss = model.separation_loss(x_input_map_multi, multi_mask,
                                                feas_tgt)

            # Separation loss is weighted five times the SGM loss.
            loss = sgm_loss + 5 * ss_loss
            if config.unit_norm and len(unit_dis):
                print 'unit_dis masked mean:', unit_dis.data[0]
                lera.log({
                    'unit_dis': unit_dis.data[0],
                })
                loss = loss + unit_dis
            if config.reID:
                # Feed the separated outputs back through the model, one
                # speaker per sequence, and add the resulting SGM and
                # separation losses to the total.
                print '#' * 30 + 'ReID part ' + '#' * 30
                predict_multi_map = multi_mask * x_input_map_multi
                predict_multi_map = predict_multi_map.view(
                    -1, mix_speech_len, speech_fre).transpose(0, 1)
                # One <BOS>, speaker, <EOS> target per (sample, speaker) pair.
                tgt_reID = []
                for spks in raw_tgt:
                    for spk in spks:
                        one_spk = [dict_spk2idx['<BOS>']] + [
                            dict_spk2idx[spk]
                        ] + [dict_spk2idx['<EOS>']]
                        tgt_reID.append(one_spk)
                # NOTE(review): the three tensors below call .cuda()
                # regardless of use_cuda.
                tgt_reID = Variable(
                    torch.from_numpy(np.array(
                        tgt_reID, dtype=np.int))).transpose(0, 1).cuda()
                src_len_reID = Variable(
                    torch.LongTensor(topk * config.batch_size).zero_() +
                    mix_speech_len).unsqueeze(0).cuda()
                tgt_len_reID = Variable(
                    torch.LongTensor(topk * config.batch_size).zero_() +
                    1).unsqueeze(0).cuda()
                outputs_reID, targets_reID, multi_mask_reID = model(
                    predict_multi_map, src_len_reID, tgt_reID, tgt_len_reID
                )  # outputs here are the hidden_outputs, the hidden layer before the final classification; usable directly
                if 1 and len(opt.gpus) > 1:
                    sgm_loss_reID, num_total_reID, _xx = model.module.compute_loss(
                        outputs_reID, targets_reID, opt.memory)
                else:
                    sgm_loss_reID, num_total_reID, _xx = model.compute_loss(
                        outputs_reID, targets_reID, opt.memory)
                print 'loss for SGM-reID mthis batch:', sgm_loss_reID.data[
                    0] / num_total_reID
                loss = loss + sgm_loss_reID
                if config.WFM:
                    feas_tgt = x_input_map_multi.data * WFM_mask
                if 1 and len(opt.gpus) > 1:
                    ss_loss_reID = model.module.separation_loss(
                        predict_multi_map.transpose(0, 1).unsqueeze(1),
                        multi_mask_reID.transpose(0, 1),
                        feas_tgt.view(-1, 1, mix_speech_len, speech_fre))
                else:
                    ss_loss_reID = model.separation_loss(
                        predict_multi_map.transpose(0, 1).unsqueeze(1),
                        multi_mask_reID.transpose(0, 1),
                        feas_tgt.view(-1, 1, mix_speech_len, speech_fre))
                loss = loss + ss_loss_reID
                print '#' * 30 + 'ReID part ' + '#' * 30

            # dis_loss model
            if config.is_dis:
                dis_loss = models.loss.dis_loss(config, topk, model_dis,
                                                x_input_map_multi, multi_mask,
                                                feas_tgt, func_dis)
                loss = loss + dis_loss
                # print 'dis_para',model_dis.parameters().next()[0]
                # print 'ss_para',model.parameters().next()[0]

            loss.backward()
            # print 'totallllllllllll loss:',loss
            total_loss_sgm += sgm_loss.data[0]
            total_loss_ss += ss_loss.data[0]
            lera.log({
                'sgm_loss': sgm_loss.data[0],
                'ss_loss': ss_loss.data[0],
            })
            if config.reID:
                lera.log({
                    'reID_sgm_loss': sgm_loss_reID.data[0],
                    'reID_ss_loss': ss_loss_reID.data[0],
                })
            total_loss += loss.data[0]
            report_total += num_total
            optim.step()
            if config.is_dis:
                optim_dis.step()

            updates += 1
            # Every 30 updates: log the running SGM / separation losses
            # (averaged over 30) and reset their accumulators.
            if updates % 30 == 0:
                logging(
                    "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f\n"
                    % (time.time() - start_time, epoch, updates, loss /
                       num_total, total_loss_sgm / 30.0, total_loss_ss / 30.0))
                total_loss_sgm, total_loss_ss = 0, 0

            # continue

            if 0 or updates % config.eval_interval == 0 and epoch > 1:
                logging(
                    "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n"
                    % (time.time() - start_time, epoch, updates,
                       total_loss / report_total))
                print('evaluating after %d updates...\r' % updates)
                # `eval` is called with the epoch number, so it is presumably
                # a project-level evaluation function - confirm.
                score = eval(epoch)
                for metric in config.metric:
                    scores[metric].append(score[metric])
                    # Save when the new score ties or beats the history
                    # (the new score has already been appended).
                    if metric == 'micro_f1' and score[metric] >= max(
                            scores[metric]):
                        save_model(log_path + 'best_' + metric +
                                   '_checkpoint.pt')
                    if metric == 'hamming_loss' and score[metric] <= min(
                            scores[metric]):
                        save_model(log_path + 'best_' + metric +
                                   '_checkpoint.pt')

                model.train()
                total_loss = 0
                # NOTE(review): resetting to 0 makes later
                # `time.time() - start_time` values equal the raw timestamp;
                # possibly meant to be time.time() - confirm.
                start_time = 0
                report_total = 0
        except RuntimeError, info:
            print '**************Error occurs here************:', info
            continue

        # Periodic checkpoint, named after the current update count.
        if updates % config.save_interval == 1:
            save_model(log_path + 'TDAA2019_{}.pt'.format(updates))
Exemple #3
0
def train(epoch):
    """Run one pass over the training generator, updating ``model`` in place.

    For every batch this builds the source/target tensors, runs the model,
    sums the SGM loss, the separation loss and an optional discriminator
    loss, back-propagates and steps the optimizer(s). Losses are reported
    through ``print``, ``logging`` and ``lera.log``; evaluation and
    checkpointing run on intervals taken from ``config``.

    The function depends on many module-level names (``model``, ``config``,
    ``opt``, ``optim``, ``scheduler``, ``lera``, ``dict_spk2idx``,
    ``mix_speech_len``, ``use_cuda``, ...) and mutates the globals declared
    in the ``global`` statement below. No exceptions are caught here.

    Args:
        epoch: Epoch number. Used in log lines, stored on
            ``model.current_epoch`` when ``opt.model == 'gated'``, passed to
            ``eval`` and used to gate evaluation (``epoch > 1``).

    Returns:
        None.
    """
    # NOTE(review): `e` is assigned here but only declared global further
    # down; Python 2 warns about this ordering and Python 3 rejects it.
    e = epoch
    model.train()

    if config.schedule:
        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])
        if config.is_dis:
            scheduler_dis.step()
        lera.log({
            'lr': scheduler.get_lr()[0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    global e, updates, total_loss, start_time, report_total, total_loss_sgm, total_loss_ss
    if config.MLMSE:
        global Var

    train_data_gen = prepare_data('once', 'train')
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in trainloader:
    while True:
        train_data = train_data_gen.next()
        # The generator signals exhaustion by yielding False.
        if train_data == False:
            break  # if this epoch's generator has no more data, move straight on to the next epoch

        src = Variable(torch.from_numpy(train_data['mix_feas']))
        # raw_tgt = [spk.keys() for spk in train_data['multi_spk_fea_list']]
        # Speaker names per sample, sorted so the order is deterministic.
        raw_tgt = [
            sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']
        ]
        feas_tgt = models.rank_feas(
            raw_tgt, train_data['multi_spk_fea_list'])  # target spectrograms here

        # Make sure the tensors below are all LongTensor (long integers).
        # Each row is [0] + speaker indices + [<EOS>], transposed so dim 0 is
        # the sequence.
        tgt = Variable(
            torch.from_numpy(
                np.array([[0] + [dict_spk2idx[spk]
                                 for spk in spks] + [dict_spk2idx['<EOS>']]
                          for spks in raw_tgt],
                         dtype=np.int))).transpose(0, 1)  # convert to indices, then add the start and end symbols at the front and back.
        # Every sample gets the same source length, mix_speech_len.
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        # The target length of the whole batch is taken from the first
        # sample's speaker count.
        tgt_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            len(train_data['multi_spk_fea_list'][0])).unsqueeze(0)
        if use_cuda:
            # NOTE(review): src is only transposed on the CUDA path, yet it
            # is transposed back unconditionally further down.
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()
        # optim.optimizer.zero_grad()
        outputs, targets, multi_mask = model(
            src, src_len, tgt,
            tgt_len)  # outputs here are the hidden_outputs, the hidden layer before the final classification; usable directly
        print 'mask size:', multi_mask.size()

        # With more than one GPU the model's own methods are reached through
        # `.module`.
        if 1 and len(opt.gpus) > 1:
            sgm_loss, num_total, num_correct = model.module.compute_loss(
                outputs, targets, opt.memory)
        else:
            sgm_loss, num_total, num_correct = model.compute_loss(
                outputs, targets, opt.memory)
        print 'loss for SGM,this batch:', sgm_loss.data[0] / num_total

        src = src.transpose(0, 1)
        # expand the raw mixed-features to topk channel.
        siz = src.size()
        assert len(siz) == 3
        topk = feas_tgt.size()[1]
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk, siz[1],
                                                      siz[2])
        multi_mask = multi_mask.transpose(0, 1)

        if 1 and len(opt.gpus) > 1:
            if config.MLMSE:
                Var = model.module.update_var(x_input_map_multi, multi_mask,
                                              feas_tgt)
                lera.log_image(u'Var weight',
                               Var.data.cpu().numpy().reshape(
                                   config.speech_fre, config.speech_fre,
                                   1).repeat(3, 2),
                               clip=(-1, 1))
                ss_loss = model.module.separation_loss(x_input_map_multi,
                                                       multi_mask, feas_tgt,
                                                       Var)
            else:
                ss_loss = model.module.separation_loss(x_input_map_multi,
                                                       multi_mask, feas_tgt)
        else:
            # NOTE(review): config.MLMSE is not consulted on this path.
            ss_loss = model.separation_loss(x_input_map_multi, multi_mask,
                                            feas_tgt)

        # Unweighted sum of the two losses.
        loss = sgm_loss + ss_loss
        # dis_loss model
        if config.is_dis:
            dis_loss = models.loss.dis_loss(config, topk, model_dis,
                                            x_input_map_multi, multi_mask,
                                            feas_tgt, func_dis)
            loss = loss + dis_loss
            # print 'dis_para',model_dis.parameters().next()[0]
            # print 'ss_para',model.parameters().next()[0]

        loss.backward()
        # print 'totallllllllllll loss:',loss
        total_loss_sgm += sgm_loss.data[0]
        total_loss_ss += ss_loss.data[0]
        lera.log({
            'sgm_loss': sgm_loss.data[0],
            'ss_loss': ss_loss.data[0],
        })
        total_loss += loss.data[0]
        report_total += num_total
        optim.step()
        if config.is_dis:
            optim_dis.step()

        updates += 1
        # Every 30 updates: log the running SGM / separation losses (averaged
        # over 30) and reset their accumulators.
        if updates % 30 == 0:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f\n"
                % (time.time() - start_time, epoch, updates, loss / num_total,
                   total_loss_sgm / 30.0, total_loss_ss / 30.0))
            total_loss_sgm, total_loss_ss = 0, 0

        # continue

        if 0 or updates % config.eval_interval == 0 and epoch > 1:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n" %
                (time.time() - start_time, epoch, updates,
                 total_loss / report_total))
            print('evaluating after %d updates...\r' % updates)
            # `eval` is called with the epoch number, so it is presumably a
            # project-level evaluation function - confirm.
            score = eval(epoch)
            for metric in config.metric:
                scores[metric].append(score[metric])
                # Save when the new score ties or beats the history (the new
                # score has already been appended).
                if metric == 'micro_f1' and score[metric] >= max(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')
                if metric == 'hamming_loss' and score[metric] <= min(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')

            model.train()
            total_loss = 0
            # NOTE(review): resetting to 0 makes later
            # `time.time() - start_time` values equal the raw timestamp;
            # possibly meant to be time.time() - confirm.
            start_time = 0
            report_total = 0

        # Periodic checkpoint. The file name depends only on config.is_dis,
        # so successive saves use the same path.
        if updates % config.save_interval == 1:
            save_model(log_path +
                       'checkpoint_v2_withdis{}.pt'.format(config.is_dis))
Exemple #4
0
def train(epoch, step):
    """Train the encoder, decoder and embedding table until ``total_steps``.

    Repeatedly iterates a shuffled ``DataLoader`` over ``datasets[dataset]``.
    For each batch the encoder output is matched to embeddings via
    ``min_dist``, three losses (commitment, embedding, reconstruction) are
    back-propagated, and the decoder's gradient with respect to the selected
    embeddings is passed straight back into the encoder output. Metrics are
    sent to ``lera``; at most once every 60 seconds a side-by-side
    input/reconstruction image is logged.

    The original implementation called itself once per data pass, so a long
    run could exceed the interpreter's recursion limit. This version uses a
    loop and otherwise performs the same work in the same order.

    Depends on module-level names (``enc``, ``dec``, ``embeddings``,
    ``optimizer``, ``emb_count``, ``sensitivity``, ``beta``, ``K``, ``D``,
    ``total_steps``, ``lera``, ...).

    Args:
        epoch: Starting epoch counter. Incremented once per data pass; not
            otherwise used by the visible code.
        step: Number of optimisation steps already taken. Training runs at
            least one full pass and stops after the first pass that ends with
            ``step >= total_steps``.

    Returns:
        None.

    Raises:
        RuntimeError: If a pass yields no batches while ``step`` is still
            below ``total_steps`` (the original would have recursed until the
            interpreter gave up).
    """
    #lera.log('epoch', epoch)
    while True:
        epoch += 1
        step_before_pass = step

        for batch, _ in DataLoader(datasets[dataset], batch_size=batch_size, pin_memory=use_cuda, num_workers=2, shuffle=True, drop_last=True):
            if use_cuda:
                batch = batch.cuda()

            step += 1

            ze = enc(V(batch))

            # Index of the chosen embedding per position; ze is re-wrapped so
            # the lookup does not back-propagate into the encoder.
            index = min_dist(V(ze.data), embeddings)
            sz = index.size()

            zq = (embeddings[index.view(-1)]       # [batch_size * x * x, D] containing vectors from embeddings
                    .view(sz[0], sz[1], sz[2], D)  # [batch_size, x, x, D]
                    .permute(0, 3, 1, 2))          # [batch_size, D, x, x]

            # Pulls the selected embeddings towards the (detached) encoder
            # output, plus a small L2 penalty on the whole table.
            emb_loss = (zq - V(ze.data)).pow(2).sum(1).mean() + 1e-2 * embeddings.pow(2).mean()

            # detach zq so it won't backprop to embeddings with recon loss
            zq = V(zq.data, requires_grad=True)

            output = dec(zq)

            # Pulls the encoder output towards the (detached) embeddings.
            commit_loss = beta * (ze - V(zq.data)).pow(2).sum(1).mean()
            recon_loss = F.mse_loss(output, V(batch))

            optimizer.zero_grad()

            # retain_graph keeps the encoder graph alive for ze.backward below.
            commit_loss.backward(retain_graph=True)
            emb_loss.backward()
            recon_loss.backward()

            # pass data term gradient from decoder to encoder
            ze.backward(zq.grad)

            optimizer.step()

            # Mark embeddings selected this step, then decay every counter by
            # 0.01 with a floor of 0.
            emb_count[index.data.view(-1)] = 1
            emb_count.sub_(0.01).clamp_(min=0)
            in_use = emb_count.gt(0)
            unique_embeddings = in_use.sum()

            # Grow sensitivity in proportion to the share of unused
            # embeddings, then reset it for the ones currently in use.
            sensitivity.add_(emb_loss.data[0] * (K - unique_embeddings) / K)
            sensitivity[in_use] = 0

            lera.log({
                'recon_loss': recon_loss.data[0],
                'commit_loss': commit_loss.data[0],
                'unique_embeddings': unique_embeddings,
                }, console=True)

            # make comparison image
            # NOTE(review): the reshape below assumes at least 8 samples per
            # batch and 3 channels - confirm against batch_size / dataset.
            if lera.every(seconds=60):
                shown_inputs = batch.cpu()[0:8,:,:,:]
                w = shown_inputs.size(-1)
                shown_outputs = output.data.cpu()[0:8,:,:,:]
                result = (torch.stack([shown_inputs, shown_outputs])  # [2, 8, 3, w, w]
                            .transpose(0, 1).contiguous()             # [8, 2, 3, w, w]
                            .view(4, 4, 3, w, w)                      # [4, 4, 3, w, w]
                            .permute(0, 3, 1, 4, 2).contiguous()      # [4, w, 4, w, 3]
                            .view(w * 4, w * 4, 3))                   # [w * 4, w * 4, 3]
                lera.log_image('reconstruction', result.numpy(), clip=(0, 1))

        # Stop once the step budget is reached; otherwise run another pass.
        if step >= total_steps:
            break
        if step == step_before_pass:
            # No batch was produced, so another pass could never make progress.
            raise RuntimeError(
                'DataLoader yielded no batches; cannot reach total_steps')