Exemplo n.º 1
0
def train(args, train_dataset, model_vae, encoder_tokenizer, decoder_tokenizer, table_name):
    """Train the VAE model on ``train_dataset``.

    Builds a random (or distributed) sampler and dataloader, an AdamW optimizer
    with a linear warmup/decay schedule, optionally wraps the model for apex
    fp16, ``DataParallel`` or ``DistributedDataParallel``, and then runs the
    training loop with a per-batch KL weight ``beta`` taken from
    ``frange_cycle_zero_linear``.

    Side effects:
        * Writes a row of loss statistics to table ``table_name`` through the
          module-level ``ts`` client on every batch whose ``global_step`` is a
          multiple of 5.
        * Every ``args.save_steps`` optimisation steps (on rank -1/0), saves the
          encoder and decoder plus ``training_args.bin`` under
          ``args.output_dir``.
        * Mutates ``args`` (``train_batch_size``, possibly ``num_train_epochs``)
          and the model's ``args.beta`` / ``args.fb_mode``.

    Args:
        args: Training configuration namespace.
        train_dataset: Dataset whose items collate into
            ``(tokenized_text0, tokenized_text1, tokenized_text_lengths)``.
        model_vae: VAE model exposing ``args``, ``encoder`` and ``decoder`` and
            returning ``(loss_rec, loss_kl, loss)`` from its forward pass.
        encoder_tokenizer: Tokenizer used for MLM masking and evaluation.
        decoder_tokenizer: Tokenizer passed to ``evaluate``.
        table_name: Target table for the periodic loss rows.

    Returns:
        ``(global_step, average_loss_per_step)``; the average is 0.0 when no
        optimisation step was taken.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    batches_per_epoch = len(train_dataloader)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (batches_per_epoch // args.gradient_accumulation_steps) + 1
    else:
        t_total = batches_per_epoch // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Forward/backward must go through the (possibly) wrapped model; otherwise
    # DataParallel / DistributedDataParallel are silently bypassed (no batch
    # split across GPUs, no gradient all-reduce between ranks). The wrappers do
    # not forward plain attributes such as ``args``, ``encoder`` or ``decoder``,
    # so those are accessed on the underlying module instead.
    model_core = model_vae.module if hasattr(model_vae, 'module') else model_vae

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0

    model_vae.zero_grad()

    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])

    # One beta value per batch of the whole run, indexed by the run-wide batch
    # index below (frange_cycle_zero_linear is defined elsewhere).
    n_iter = int(args.num_train_epochs) * batches_per_epoch
    beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=1, ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)

    def _save_component(component, output_dir):
        # Save one sub-model plus the training args into output_dir. With
        # args.use_philly the save is retried until it succeeds (each failure is
        # logged); otherwise the first failure propagates to the caller.
        os.makedirs(output_dir, exist_ok=True)
        while True:
            try:
                component.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
            except Exception:
                if not args.use_philly:
                    raise
                logger.warning("Saving model checkpoint to %s failed; retrying", output_dir, exc_info=True)
            else:
                logger.info("Saving model checkpoint to %s", output_dir)
                return

    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Index of this batch over the whole run (beta schedule, logging).
            iter_idx = step + epoch * batches_per_epoch

            tokenized_text0, tokenized_text1, tokenized_text_lengths = batch

            # Trim padding: the per-batch maxima of length columns 0 and 1 bound
            # tokenized_text0 and tokenized_text1 respectively.
            max_len_values, _ = tokenized_text_lengths.max(0)
            tokenized_text0 = tokenized_text0[:, :max_len_values[0]]
            tokenized_text1 = tokenized_text1[:, :max_len_values[1]]

            # Encoder input is optionally MLM-masked; the reconstruction target
            # is always tokenized_text1.
            inputs = mask_tokens(tokenized_text0, encoder_tokenizer, args)[0] if args.mlm else tokenized_text0
            inputs = inputs.to(args.device)
            labels = tokenized_text1.to(args.device)

            model_vae.train()

            beta_t = beta_t_list[iter_idx]
            model_core.args.beta = beta_t
            # fb_mode: 2 when the deterministic connector is requested,
            # otherwise 0 while beta is exactly zero and 1 afterwards.
            if args.use_deterministic_connect:
                model_core.args.fb_mode = 2
            elif beta_t == 0.0:
                model_core.args.fb_mode = 0
            else:
                model_core.args.fb_mode = 1

            loss_rec, loss_kl, loss = model_vae(inputs, labels)

            if args.n_gpu > 1:
                # mean() to average on multi-gpu parallel training
                loss_rec = loss_rec.mean()
                loss_kl = loss_kl.mean()
                loss = loss.mean()

            if args.use_philly:
                print("PROGRESS: {}%".format(round(100 * iter_idx / (int(args.num_train_epochs) * batches_per_epoch), 4)))
                print("EVALERR: {}%".format(loss_rec))

            epoch_iterator.set_description(
                f'iter: {iter_idx}; loss: {loss.item():.3f}; '
                f'loss_rec: {loss_rec.item():.3f}; loss_kl: {loss_kl.item():.3f}; '
                f'beta: {model_core.args.beta:.3f}'
            )

            # Checked on every batch, so with gradient accumulation several rows
            # can be written for the same global_step.
            if global_step % 5 == 0:
                row = {
                    'PartitionKey': 'MILU_Rule_Rule_Template',
                    'RowKey': str(datetime.now()),
                    'ExpName': args.ExpName,
                    'iter': str(iter_idx),
                    'loss': str(loss.item()),
                    'loss_rec': str(loss_rec.item()),
                    'loss_kl': str(loss_kl.item()),
                    'beta': str(model_core.args.beta),
                }
                ts.insert_entity(table_name, row)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_vae.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model_vae.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model_core, encoder_tokenizer, decoder_tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    _save_component(model_core.encoder, os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step)))
                    _save_component(model_core.decoder, os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step)))

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # Avoid ZeroDivisionError when no optimisation step was taken.
    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss
def train(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, table_name):
    """Train the VAE over the ``*seq64*.json`` files in ``args.train_data_file``.

    ``train_dataloader`` is a project-specific loader that exposes ``reset()``,
    ``file_idx`` and ``num_examples``. It is iterated once per data file inside
    each epoch (NOTE(review): that it yields one file per pass is inferred from
    the loop structure — confirm).

    The KL weight ``beta`` follows ``frange_cycle_zero_linear`` indexed by
    ``global_step`` when ``args.use_beta_schedule`` is set. Otherwise it stays
    at 0.0. Every ``args.save_steps`` optimisation steps (on rank -1/0), the
    encoder and decoder are saved under ``args.output_dir``.

    Args:
        args: Training configuration namespace (mutated: ``train_batch_size``,
            possibly ``num_train_epochs``).
        train_dataloader: Multi-file dataloader yielding
            ``(tokenized_text0, tokenized_text1, tokenized_text_lengths)``.
        model_vae: VAE model, optionally already wrapped (``.module``). It must
            expose ``args``, ``encoder`` and ``decoder`` and return
            ``(loss_rec, loss_kl, loss)``.
        encoder_tokenizer: Tokenizer used for MLM masking and evaluation.
        decoder_tokenizer: Tokenizer passed to ``evaluate``.
        table_name: Unused here (kept for interface compatibility).

    Returns:
        ``(global_step, average_loss_per_step)``; the average is 0.0 when no
        optimisation step was taken.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        # NOTE(review): not multiplied by num_train_epochs (that factor was
        # commented out in the original) — confirm this is intended.
        t_total = len(train_dataloader) // args.gradient_accumulation_steps

    if args.distributed:
        # Presumably each process handles 1/world_size of the steps — confirm.
        t_total = t_total // ompi_size()

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)

    # NOTE(review): no DataParallel / DistributedDataParallel wrapping happens
    # here; the original code used ``model_vae.module``, so the caller
    # presumably wraps the model — confirm. Attribute access (args, encoder,
    # decoder) goes through the underlying module when one exists.
    model_core = model_vae.module if hasattr(model_vae, 'module') else model_vae

    num_files = len(list(Path(args.train_data_file).glob('*seq64*.json')))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num files = %d", num_files)
    logger.info("  Num examples of first file = %d", train_dataloader.num_examples)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0

    model_vae.zero_grad()
    # Epoch progress bar; the epoch loop below uses range(), so this bar is only closed.
    num_train_epochs_iterator = trange(int(args.num_train_epochs), desc="Epoch")

    n_iter_per_file = len(train_dataloader) / args.per_gpu_train_batch_size
    n_iter = int(args.num_train_epochs * n_iter_per_file * num_files)
    beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=10, ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)
    beta_t = 0.0

    def _save_component(component, output_dir):
        # Save one sub-model plus the training args into output_dir. With
        # args.use_philly the save is retried until it succeeds (each failure is
        # logged); otherwise the first failure propagates to the caller.
        os.makedirs(output_dir, exist_ok=True)
        while True:
            try:
                component.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
            except Exception:
                if not args.use_philly:
                    raise
                logger.warning("Saving model checkpoint to %s failed; retrying", output_dir, exc_info=True)
            else:
                logger.info("Saving model checkpoint to %s", output_dir)
                return

    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in range(int(args.num_train_epochs)):
        train_dataloader.reset()
        # NOTE(review): runs num_files - 1 passes per epoch — confirm that
        # skipping one file is intended.
        for _ in range(num_files - 1):
            logger.info(f"Rank {ompi_rank()}, Epoch {epoch}, File idx {train_dataloader.file_idx}")
            for step, batch in enumerate(train_dataloader):
                tokenized_text0, tokenized_text1, _lengths = batch

                # Encoder input is optionally MLM-masked; the reconstruction
                # target is always tokenized_text1.
                inputs = mask_tokens(tokenized_text0, encoder_tokenizer, args)[0] if args.mlm else tokenized_text0
                inputs = inputs.to(args.device)
                labels = tokenized_text1.to(args.device)

                model_vae.train()

                if args.use_beta_schedule:
                    # NOTE(review): past the end of the schedule beta jumps to a
                    # literal 1.0 rather than args.beta — confirm.
                    beta_t = beta_t_list[global_step] if global_step < len(beta_t_list) else 1.0
                model_core.args.beta = beta_t

                # fb_mode: 2 when the deterministic connector is requested,
                # otherwise 0 while beta is exactly zero and 1 afterwards.
                if args.use_deterministic_connect:
                    model_core.args.fb_mode = 2
                elif beta_t == 0.0:
                    model_core.args.fb_mode = 0
                else:
                    model_core.args.fb_mode = 1

                loss_rec, loss_kl, loss = model_vae(inputs, labels)

                # mean() to average on multi-gpu parallel training
                loss_rec = loss_rec.mean()
                loss_kl = loss_kl.mean()
                loss = loss.mean()

                if args.use_philly and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info("Steps {}, Rank {}, File {}, Epoch: [{}/{}][{}/{}], Beta: {}, Loss: {}".format(global_step, ompi_rank(), train_dataloader.file_idx,
                                epoch, args.num_train_epochs, step, len(train_dataloader), model_core.args.beta, loss_rec))
                    logger.info("PROGRESS: {}%".format(round(100*(step + epoch*len(train_dataloader))/(int(args.num_train_epochs) * len(train_dataloader)), 4)))
                    logger.info("EVALERR: {}%".format(loss_rec))

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model_vae.parameters(), args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model_vae.zero_grad()
                    global_step += 1

                    if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        if args.local_rank == -1 and args.evaluate_during_training:
                            results = evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer)
                            for key, value in results.items():
                                tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                        tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                        logging_loss = tr_loss

                    if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                        _save_component(model_core.encoder, os.path.join(args.output_dir, 'checkpoint-encoder-{}-{}'.format(global_step, model_core.args.beta)))
                        _save_component(model_core.decoder, os.path.join(args.output_dir, 'checkpoint-decoder-{}-{}'.format(global_step, model_core.args.beta)))

                if args.max_steps > 0 and global_step > args.max_steps:
                    break

            if args.max_steps > 0 and global_step > args.max_steps:
                break

        # Also leave the epoch loop; previously training resumed with the next epoch.
        if args.max_steps > 0 and global_step > args.max_steps:
            break

    num_train_epochs_iterator.close()

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # Avoid ZeroDivisionError when no optimisation step was taken.
    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss
Exemplo n.º 3
0
def train(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, table_name):
    """Scaffold of the dialog-response training loop (training step not yet written).

    Sets up the TensorBoard writer, optimizer, scheduler, optional apex fp16
    and DataParallel / DistributedDataParallel wrapping, and a beta schedule.
    It then iterates over ``train_dataloader`` and only logs each batch's
    context / response token ids. No forward or backward pass is run.
    ``global_step`` is still advanced once every
    ``args.gradient_accumulation_steps`` batches, so ``args.max_steps`` is
    honoured.

    Args:
        args: Training configuration namespace (mutated: ``train_batch_size``,
            possibly ``num_train_epochs``).
        train_dataloader: Loader yielding
            ``(input_ids_bert_ctx, input_ids_bert, input_ids_gpt, token_lengths)``
            and exposing ``num_examples``.
        model_vae: VAE model to be trained.
        encoder_tokenizer: Encoder tokenizer (unused so far).
        decoder_tokenizer: Decoder tokenizer (unused so far).
        table_name: Unused so far.

    Returns:
        The number of counted optimisation steps (``global_step``).
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", train_dataloader.num_examples)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    # Placeholders for the not-yet-implemented training step.
    tr_loss, logging_loss = 0.0, 0.0

    model_vae.zero_grad()

    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])

    # Per-batch beta schedule, prepared for the not-yet-implemented training step.
    n_iter = int(args.num_train_epochs) * len(train_dataloader)
    beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=1, ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)

    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            input_ids_bert_ctx, input_ids_bert, input_ids_gpt, token_lengths = batch

            # Debug dump of every batch. Lazy %-style arguments mean the token
            # tensors are only rendered to text when INFO logging is enabled.
            logger.info('Conxtext in Bert, Length %s ; Tokens: %s', token_lengths[0], input_ids_bert_ctx)
            logger.info('Response in Bert, Length %s ; Tokens: %s', token_lengths[1], input_ids_bert)
            logger.info('Response in GPT2, Length %s ; Tokens: %s', token_lengths[2], input_ids_gpt)
            # TODO: write down training scripts for dialog response generation

            if (step + 1) % args.gradient_accumulation_steps == 0:
                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step