Exemplo n.º 1
0
def train(epoch, trainloader, steps_per_val, base_lr, total_epochs,
          optimizer, model, adjust_learning_rate, print_freq, image_freq,
          image_outdir, local_rank, sub_losses):
    """Run one pass over ``trainloader``, taking an optimizer step per batch.

    For every batch the model is called as ``model(a, fg, bg)``; the means of
    its first three outputs are summed into the training loss, and any
    remaining outputs are handed to ``write_image`` for visualization.

    Args:
        epoch: Epoch index; ``epoch * steps_per_val`` is the global iteration
            offset for this pass.
        trainloader: Iterable yielding ``(a, fg, bg)`` triples.
        steps_per_val: Multiplier converting epochs to iterations
            (presumably the number of batches per epoch -- confirm with caller).
        base_lr: Base learning rate forwarded to ``adjust_learning_rate``.
        total_epochs: Total number of epochs; with ``steps_per_val`` it gives
            the iteration total passed to the scheduler and shown in logs.
        optimizer: Optimizer stepped after each backward pass.
        model: Module returning an indexable sequence whose entries 0, 1, 2
            are the three loss terms.
        adjust_learning_rate: Callable invoked after each step as
            ``adjust_learning_rate(optimizer, base_lr, max_iters, cur_iter)``.
        print_freq: Log a progress line every ``print_freq`` iterations.
        image_freq: Write visualizations every ``image_freq`` iterations.
        image_outdir: Output location forwarded to ``write_image``.
        local_rank: Process rank; logging and image writing happen only when
            it is ``<= 0``.
        sub_losses: Sequence of three labels used for the loss terms in the
            log line.

    Returns:
        None.
    """
    model.train()

    step_timer = AverageMeter()
    loss_meter = AverageMeter()

    iter_offset = epoch * steps_per_val
    max_iters = total_epochs * steps_per_val
    # NOTE(review): presumably a negative rank means "not distributed" and 0
    # is the main process -- confirm against the launcher.
    is_main_rank = local_rank <= 0

    def run_step(batch):
        """Forward, backward and optimizer step for a single batch.

        Kept as a nested function so the intermediate tensors of the step are
        locals that go out of scope as soon as it returns.
        """
        a, fg, bg = batch      # [B, 3, 3 or 1, H, W]
        out = model(a, fg, bg)

        l_alpha = out[0].mean()
        l_comp = out[1].mean()
        l_grad = out[2].mean()

        # Plain Python floats for logging, taken before the backward pass.
        scalars = (
            l_alpha.detach().item(),
            l_comp.detach().item(),
            l_grad.detach().item(),
        )

        # Unweighted sum of the three terms. A fourth term
        # (L_temp = out[3].mean()) and 0.5-weighted variants are disabled.
        total = l_alpha + l_comp + l_grad

        model.zero_grad()
        total.backward()
        optimizer.step()
        return total.detach(), scalars, out[3:]

    last_tick = time.time()
    for i_iter, dp in enumerate(trainloader):
        global_iter = i_iter + iter_offset

        step_loss, (alpha_val, comp_val, grad_val), extras = run_step(dp)

        # reduce_tensor is defined elsewhere; presumably it combines the loss
        # across ranks -- confirm.
        synced_loss = reduce_tensor(step_loss)

        # Wall-clock time since the end of the previous iteration.
        now = time.time()
        step_timer.update(now - last_tick)
        last_tick = time.time()

        loss_meter.update(synced_loss.item())
        torch_barrier()

        adjust_learning_rate(optimizer, base_lr, max_iters, global_iter)

        if i_iter % print_freq == 0 and is_main_rank:
            progress = 'Iter:[{}/{}], Time: {:.2f}, '.format(
                global_iter, max_iters, step_timer.average())
            summary = 'lr: {}, Avg. Loss: {:.6f} | Current: Loss: {:.6f}, '.format(
                [group['lr'] for group in optimizer.param_groups],
                loss_meter.average(), loss_meter.value())
            breakdown = '{}: {:.4f} {}: {:.4f} {}: {:.4f}'.format(
                sub_losses[0], alpha_val,
                sub_losses[1], comp_val,
                sub_losses[2], grad_val)
            logging.info(progress + summary + breakdown)

        if i_iter % image_freq == 0 and is_main_rank:
            write_image(image_outdir, extras, global_iter)