Example #1
def train(epoch):
    net.train()
    train_loss = 0
    correct    = 0
    total      = 0
    # Re-create the optimizer each epoch so the scheduled learning rate takes effect.
    optimizer = optim.SGD(net.parameters(), lr=lr_schedule(lr, epoch), momentum=0.9, weight_decay=5e-4)
    
    print('Training Epoch: #%d, LR: %.4f' % (epoch, lr_schedule(lr, epoch)))
    for idx, (inputs, labels) in enumerate(train_loader):
        if is_use_cuda:
            inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Global TensorBoard step; 50000 is the hard-coded number of training samples per epoch.
        writer.add_scalar('Train/Loss', loss.item(), epoch * 50000 + batch_size * (idx + 1))
        train_loss += loss.item()
        _, predict = torch.max(outputs, 1)
        total += labels.size(0)
        correct += predict.eq(labels).cpu().sum().double()
        
        sys.stdout.write('\r')
        sys.stdout.write('[%s] Training Epoch [%d/%d] Iter[%d/%d]\t\tLoss: %.4f Acc@1: %.3f'
                        % (time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())),
                           epoch, num_epochs, idx, len(train_dataset) // batch_size, 
                          train_loss / (batch_size * (idx + 1)), correct / total))
        sys.stdout.flush()
    writer.add_scalar('Train/Accuracy', correct / total, epoch)
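
The function above is a fragment: names such as net, criterion, writer, train_loader, lr, lr_schedule, batch_size, num_epochs, is_use_cuda and device are module-level globals defined elsewhere in the repository. A minimal, hypothetical setup under which it runs, assuming a CIFAR-10 loader and a simple step-decay schedule (every value below is a placeholder, not taken from the original code):

import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

is_use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_use_cuda else 'cpu')
batch_size = 128
num_epochs = 200
lr = 0.1

def lr_schedule(base_lr, epoch):
    # Placeholder step decay: divide the learning rate by 10 at epochs 100 and 150.
    new_lr = base_lr
    if epoch >= 100:
        new_lr *= 0.1
    if epoch >= 150:
        new_lr *= 0.1
    return new_lr

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                             transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           shuffle=True, num_workers=2)

net = torchvision.models.resnet18(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
writer = SummaryWriter()

for epoch in range(num_epochs):
    train(epoch)
writer.close()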
Example #2
def train(epoch):
    # import pydevd
    #
    # pydevd.settrace(suspend=False, trace_only_current_thread=True)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    optimizer = optim.SGD(net.parameters(), lr=lr_schedule(lr, epoch), momentum=0.9, weight_decay=5e-4)

    # log_message = 'TrainingEpoch: {:d} | LR: {:.4f}'.format(epoch, lr_schedule(lr, epoch))
    # logger.info(log_message)
    # net.module.dt = list()
    # net.module.forward_t = list()
    # net.module.backward_t = list()
    # Reset the model's forward (nfe) / backward (nbe) evaluation counters for this epoch.
    net.module.nbe = 0
    net.module.nfe = 0
    # Move every batch onto the device up front; this assumes the whole training set fits in device memory.
    batches = []
    for inputs, labels in train_loader:
        if is_use_cuda:
            batches.append((inputs.to(device), labels.to(device)))
        else:
            batches.append((inputs, labels))
    for idx, (inputs, labels) in enumerate(batches):
        # if is_use_cuda:
        #     inputs, labels = inputs.to(device), labels.to(device)
        epoch_time = time.time()  # per-iteration timer; logged below as 'time_meter'
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # f_t = net.module.f_t
        # z_t = net.module.z_t
        # forward_dt = net.module.dt
        # net.module.dt = list()
        loss.backward()
        optimizer.step()

        # backward_dt = net.module.dt
        # net.module.dt = list()

        #         writer.add_scalar('Train/Loss', loss.item(), epoch * 50000 + batch_size * (idx + 1))
        train_loss += loss.item()
        _, predict = torch.max(outputs, 1)
        total += labels.size(0)
        correct += predict.eq(labels).cpu().sum().double()

        # log_message = 'TrainingEpoch [{:d}/{:d}] | Iter[{:d}/{:d}] | Loss: {:.8f} | Acc@1: {:.4f} | FdT {:} | ' \
        #               'BdT {:} | TotalTime {:.4f} | tF {:} | tB {:} | NFE {:} | NBE {:} | PeakMemory: {:d} | ' \
        #               'z_t {:} | f_t {:}'.format(
        #     epoch, num_epochs,
        #     idx, len(train_dataset) // batch_size,
        #                             train_loss / (batch_size * (idx + 1)), correct / total,
        #     '[]', '[]', time.time() - epoch_time,
        #     str(net.module.forward_t), str(net.module.backward_t),
        #     str(net.module.nbe),
        #     str(net.module.nfe),
        #     # TODO: first, nfe and nbe have to be swaped here. Second, we have to take nfe before backward call.
        #     torch.cuda.max_memory_allocated(device),
        #     str(z_t),
        #     str(f_t),
        # )
        # Log per-iteration metrics to Weights & Biases when a run name is given
        # (summing .values() assumes nfe/nbe are per-block dict counters on the model).
        if args.wandb_name:
            wandb.log({'epoch': epoch,
                       'iteration': epoch * len(batches) + idx,
                       'train_loss': train_loss / (batch_size * (idx + 1)),
                       'train_accuracy': correct / total,
                       'time_meter': time.time() - epoch_time,
                       'nbe': sum(net.module.nbe.values()),
                       'nfe': sum(net.module.nfe.values()),
                       'PeakMemory': torch.cuda.max_memory_allocated(device)})

        # if idx % args.log_every == 0:
            # logger.info(log_message)
        #             torch.save({
        #                 "args": args,
        #                 "state_dict": net.state_dict() if torch.cuda.is_available() else net.state_dict(),
        #                 "optim_state_dict": optimizer.state_dict(),
        #             }, os.path.join(args.save, "checkpt_{:d}_{:d}.pth".format(idx, epoch)))

        # net.module.forward_t = list()
        # net.module.backward_t = list()
        net.module.nbe = 0
        net.module.nfe = 0

    # log_message = 'Train/Accuracy | Acc@1: {:.3f} | Epoch {:d}'.format(correct / total, epoch)
    # logger.info(log_message)

    # Periodic checkpoint, written every `args.save_every` epochs.
    if (epoch - 1) % args.save_every == 0:
        torch.save({
            "args": args,
            "state_dict": net.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        }, os.path.join(args.save, "checkpt_{}_{:d}.pth".format(checkpoint_id, epoch)))

    # Always overwrite the "latest" checkpoint at the end of the epoch.
    torch.save({
        "args": args,
        "state_dict": net.state_dict(),
        "optim_state_dict": optimizer.state_dict(),
    }, os.path.join(args.save, "checkpt_{}.pth".format(checkpoint_id)))
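
This second variant additionally assumes a DataParallel-wrapped net (hence net.module), an args namespace with wandb_name, save and save_every, a checkpoint_id string, and an initialised wandb run. A small, hypothetical helper for restoring the checkpoints it writes, using only standard torch.load / load_state_dict calls (load_checkpoint is not part of the original code); since state_dict() is taken from the wrapped model, the saved keys carry a 'module.' prefix and should be loaded into a model wrapped the same way:

import os

import torch

def load_checkpoint(net, optimizer, save_dir, checkpoint_id, device='cpu'):
    # The checkpoint stores the argparse namespace, the model weights and the
    # optimizer state under "args", "state_dict" and "optim_state_dict",
    # matching the torch.save calls above.
    path = os.path.join(save_dir, 'checkpt_{}.pth'.format(checkpoint_id))
    ckpt = torch.load(path, map_location=device)
    net.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optim_state_dict'])
    return ckpt['args']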