コード例 #1
0
ファイル: train.py プロジェクト: gregrolwes/ViT-pytorch
def train(args, model, device, acc_calculator):
    """Train ``model`` with a batch-all metric-learning loss.

    Performs ``args.num_steps`` optimizer updates, looping over the training
    loader for as many epochs as needed. On the main process (``local_rank``
    -1 or 0) the loss and learning rate are written to TensorBoard after every
    update. Every ``args.eval_every`` updates, the model is evaluated with
    ``valid`` and saved via ``save_model`` whenever the returned accuracy beats
    the best seen so far.

    Gradient accumulation: ``args.train_batch_size`` is divided in place by
    ``args.gradient_accumulation_steps`` (this mutates ``args``). The
    gradients of that many consecutive micro-batches are accumulated before
    each optimizer/scheduler step.

    Args:
        args: run configuration. Fields read here include ``local_rank``,
            ``output_dir``, ``name``, ``train_batch_size``,
            ``gradient_accumulation_steps``, ``learning_rate``,
            ``weight_decay``, ``num_steps``, ``decay_type``, ``warmup_steps``
            and ``eval_every``. ``get_loader``/``valid``/``save_model`` may
            read more.
        model: network whose output on an input batch is passed to
            ``BatchAllLoss`` together with the labels.
        device: device every input/label tensor is moved to.
        acc_calculator: forwarded unchanged to ``valid``.
    """
    is_main_process = args.local_rank in [-1, 0]
    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))

    # Each loader batch is one micro-batch of an accumulation window.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Prepare dataset
    train_loader, test_loader = get_loader(args)

    # Prepare optimizer and scheduler
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
                                weight_decay=args.weight_decay)
    t_total = args.num_steps
    if args.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", args.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d", args.train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    loss_fct = BatchAllLoss(margin=0.1)
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=not is_main_process)
        for step, batch in enumerate(epoch_iterator):
            # Only the first two fields (inputs, labels) are used; any extra
            # fields in the batch are ignored.
            x, y = tuple(t.to(device) for t in batch[:2])
            embeds = model(x)
            loss = loss_fct(embeds, y)
            losses.update(loss.item())  # record the unscaled loss

            if args.gradient_accumulation_steps > 1:
                # Scale so the accumulated gradient is the mean over the window.
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # Update only after a full window of micro-batches has been
            # back-propagated; stepping every micro-batch would merely shrink
            # the learning rate instead of accumulating gradients.
            if (step + 1) % args.gradient_accumulation_steps != 0:
                continue

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
            )
            if is_main_process:
                writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
                writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
            if global_step % args.eval_every == 0 and is_main_process:
                accuracy = valid(args, model, writer, test_loader, global_step, device, acc_calculator)
                if best_acc < accuracy:
                    save_model(args, model)
                    best_acc = accuracy
                model.train()

            if global_step % t_total == 0:
                break
        losses.reset()
        if global_step % t_total == 0:
            break

    if is_main_process:
        writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")
コード例 #2
0
def train(args, model):
    """Train ``model`` with a focal loss, checkpointing on the best validation loss.

    Performs ``args.num_steps`` optimizer updates, looping over the training
    loader as many epochs as needed. On the main process (``local_rank`` -1 or
    0) the loss and learning rate are written to TensorBoard after every
    update. Every ``args.eval_every`` updates, the model is evaluated with
    ``valid``, which returns ``(accuracy, eval_loss)``. The model is saved
    whenever the validation loss is no worse than the best so far. The best
    accuracy observed is logged at the end.

    Gradient accumulation: ``args.train_batch_size`` is divided in place by
    ``args.gradient_accumulation_steps`` (this mutates ``args``). Gradients of
    that many consecutive micro-batches are averaged, clipped to
    ``args.max_grad_norm``, and applied in one optimizer step.

    Args:
        args: run configuration. Fields read here include ``local_rank``,
            ``output_dir``, ``name``, ``train_batch_size``,
            ``gradient_accumulation_steps``, ``learning_rate``,
            ``weight_decay``, ``num_steps``, ``decay_type``, ``warmup_steps``,
            ``device``, ``class_number``, ``max_grad_norm`` and ``eval_every``.
        model: network called as ``model(imgs, intensity, labels)``. Its output
            is reshaped to ``(-1, args.class_number)`` logits.
    """
    is_main_process = args.local_rank in [-1, 0]
    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))

    # Each loader batch is one micro-batch of an accumulation window.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
    train_loader, test_loader = get_loader(args)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
                                weight_decay=args.weight_decay)
    t_total = args.num_steps

    if args.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", args.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    model.zero_grad()
    set_seed(args)
    losses = AverageMeter()
    global_step, best_acc, best_losses = 0, 0, np.inf
    # Built once rather than once per epoch.
    criterion = FocalLoss()
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=not is_main_process)
        for step, batch in enumerate(epoch_iterator):
            imgs, intensity, labels = tuple(t.to(args.device) for t in batch)
            outputs = model(imgs, intensity, labels)
            outputs = outputs.view(-1, args.class_number)
            labels = labels.view(-1)
            loss = criterion(outputs, labels)

            if args.gradient_accumulation_steps > 1:
                # Scale so the accumulated gradient is the mean over the window
                # (and so the "* steps" below recovers the unscaled loss).
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # Everything below runs once per optimizer update. Keeping the
            # eval/stop checks inside this branch stops them from firing at
            # global_step == 0 (0 % n == 0) or repeating on every micro-batch.
            if (step + 1) % args.gradient_accumulation_steps != 0:
                continue

            losses.update(loss.item() * args.gradient_accumulation_steps)
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
            # PyTorch expects optimizer.step() before scheduler.step().
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" %
                (global_step, t_total, losses.val))
            if is_main_process:
                writer.add_scalar("train/loss",
                                  scalar_value=losses.val,
                                  global_step=global_step)
                writer.add_scalar("train/lr",
                                  scalar_value=scheduler.get_lr()[0],
                                  global_step=global_step)
            # ``writer`` only exists on the main process.
            if global_step % args.eval_every == 0 and is_main_process:
                accuracy, eval_losses = valid(args, model, writer, test_loader,
                                              global_step)
                best_acc = max(best_acc, accuracy)
                if eval_losses <= best_losses:
                    save_model(args, model)
                    best_losses = eval_losses
                model.train()

            if global_step % t_total == 0:
                break
        losses.reset()
        if global_step % t_total == 0:
            break
    if is_main_process:
        writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")
コード例 #3
0
ファイル: train.py プロジェクト: yeongseo96/ViT-pytorch
def train(args, model):
    """Train ``model`` (which returns its own loss) for ``args.num_steps`` updates.

    Supports apex mixed precision (``args.fp16``) and apex
    ``DistributedDataParallel`` (``args.local_rank != -1``). On the main
    process (``local_rank`` -1 or 0) the loss and learning rate are written to
    TensorBoard after every update. Every ``args.eval_every`` updates, the
    model is evaluated with ``valid`` and saved via ``save_model`` whenever
    the returned accuracy improves.

    Gradient accumulation: ``args.train_batch_size`` is divided in place by
    ``args.gradient_accumulation_steps`` (this mutates ``args``). Gradients of
    that many consecutive micro-batches are averaged, clipped to
    ``args.max_grad_norm``, and applied in one optimizer step.

    Args:
        args: run configuration. Fields read here include ``local_rank``,
            ``output_dir``, ``name``, ``train_batch_size``,
            ``gradient_accumulation_steps``, ``learning_rate``,
            ``weight_decay``, ``num_steps``, ``decay_type``, ``warmup_steps``,
            ``fp16``, ``fp16_opt_level``, ``device``, ``max_grad_norm`` and
            ``eval_every``.
        model: network called as ``model(x, y)`` that returns a scalar loss.
    """
    is_main_process = args.local_rank in [-1, 0]
    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))

    # Each loader batch is one micro-batch of an accumulation window.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Prepare dataset
    train_loader, test_loader = get_loader(args)

    # Prepare optimizer and scheduler
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
                                weight_decay=args.weight_decay)
    t_total = args.num_steps
    if args.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)

    if args.fp16:
        model, optimizer = amp.initialize(models=model,
                                          optimizers=optimizer,
                                          opt_level=args.fp16_opt_level)
        # Overrides apex's initial loss scale (its first loss scaler) with 2**20.
        amp._amp_state.loss_scalers[0]._loss_scale = 2**20

    # Distributed training
    if args.local_rank != -1:
        model = DDP(model,
                    message_size=250000000,
                    gradient_predivide_factor=get_world_size())

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", args.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    model.zero_grad()
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=not is_main_process)
        for step, batch in enumerate(epoch_iterator):
            x, y = tuple(t.to(args.device) for t in batch)
            loss = model(x, y)

            if args.gradient_accumulation_steps > 1:
                # Scale so the accumulated gradient is the mean over the window.
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Undo the window scaling so the logged value is the raw loss.
                losses.update(loss.item() * args.gradient_accumulation_steps)
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                # PyTorch (>= 1.1) expects optimizer.step() before
                # scheduler.step(); the reverse order skips the first value of
                # the schedule and triggers a warning.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" %
                    (global_step, t_total, losses.val))
                if is_main_process:
                    writer.add_scalar("train/loss",
                                      scalar_value=losses.val,
                                      global_step=global_step)
                    writer.add_scalar("train/lr",
                                      scalar_value=scheduler.get_lr()[0],
                                      global_step=global_step)
                if global_step % args.eval_every == 0 and is_main_process:
                    accuracy = valid(args, model, writer, test_loader,
                                     global_step)
                    if best_acc < accuracy:
                        save_model(args, model)
                        best_acc = accuracy
                    model.train()

                if global_step % t_total == 0:
                    break
        losses.reset()
        if global_step % t_total == 0:
            break

    if is_main_process:
        writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")
コード例 #4
0
def train_engine(__C):
    """Build the network, loaders, optimizer and LR schedule from ``__C`` and train.

    Trains for ``__C.num_steps`` optimizer updates on CUDA, looping over the
    training loader as many epochs as needed. Loss and learning rate go to
    TensorBoard under ``tensorboard_log_dir/model/version``. Every
    ``__C.eval_every`` updates the model is evaluated with ``valid``. When
    accuracy improves, the state dict is saved under
    ``ckpts_dir/model/version``. The config itself is appended to
    ``result_log_dir/model/<version>.txt``.

    Mutates ``__C``: ``batch_size`` becomes the per-micro-batch size
    (divided by ``gradient_accumulation_steps``). For the ``'multi_step'``
    decay type, ``milestones`` (configured in epochs) are converted to
    optimizer steps.

    Raises:
        ValueError: if ``__C.decay_type`` is not ``'multi_step'``, ``'cosine'``
            or ``'linear'``.
    """
    # define network
    net = get_network(__C)
    net = net.cuda()

    # Each loader batch is one micro-batch of an accumulation window.
    __C.batch_size = __C.batch_size // __C.gradient_accumulation_steps

    # define dataloader
    train_loader = get_train_loader(__C)
    test_loader = get_test_loader(__C)

    # define loss function
    if __C.label_smoothing:
        loss_function = LabelSmoothingCrossEntropy(__C.smoothing)
    else:
        loss_function = nn.CrossEntropyLoss()

    # define optimizer and training parameters
    if __C.no_bias_decay:
        params = split_weights(net)
    else:
        params = net.parameters()
    optimizer = optim.SGD(params, lr=__C.lr, momentum=0.9, weight_decay=5e-4)

    # define optimizer scheduler
    warmup_steps = __C.warmup_steps
    total_steps = __C.num_steps
    if __C.decay_type == 'multi_step':
        # Milestones are configured in epochs, but the scheduler is stepped once
        # per optimizer update. One epoch is len(train_loader) micro-batches,
        # i.e. len(train_loader) // gradient_accumulation_steps updates. (The
        # previous `for i in milestones: i *= ...` loop only rebound its loop
        # variable and changed nothing.)
        steps_per_epoch = max(
            1, len(train_loader) // __C.gradient_accumulation_steps)
        __C.milestones = [m * steps_per_epoch for m in __C.milestones]
        train_scheduler = WarmupMultiStepSchedule(__C,
                                                  optimizer,
                                                  warmup_steps=warmup_steps,
                                                  t_total=total_steps)
    elif __C.decay_type == 'cosine':
        train_scheduler = WarmupCosineSchedule(optimizer,
                                               warmup_steps=warmup_steps,
                                               t_total=total_steps)
    elif __C.decay_type == 'linear':
        train_scheduler = WarmupLinearSchedule(optimizer,
                                               warmup_steps=warmup_steps,
                                               t_total=total_steps)
    else:
        # Fail fast instead of hitting an unbound `train_scheduler` later.
        raise ValueError("Unsupported decay_type: %r" % (__C.decay_type,))

    # define tensorboard writer
    writer = SummaryWriter(
        log_dir=os.path.join(__C.tensorboard_log_dir, __C.model, __C.version))

    # define model save dir
    checkpoint_dir = os.path.join(__C.ckpts_dir, __C.model, __C.version)
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir,
                                   '{net}-{global_step}-{type}.pth')

    # define log save dir
    log_dir = os.path.join(__C.result_log_dir, __C.model)
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, __C.version + '.txt')

    # write the hyper parameters to log
    with open(log_path, 'a+') as logfile:
        logfile.write(str(__C))

    # Train!
    logger.info("  ***** Running training *****")
    logger.info("  Total optimization steps = %d", __C.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d", __C.batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                __C.gradient_accumulation_steps)

    net.zero_grad()
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    while True:
        net.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True)
        # Iterate the tqdm wrapper (not the raw loader) so the bar advances.
        for step, (images, labels) in enumerate(epoch_iterator):
            images = images.cuda()
            labels = labels.cuda()
            train_outputs = net(images)
            loss = loss_function(train_outputs, labels)

            if __C.gradient_accumulation_steps > 1:
                # Scale so the accumulated gradient is the mean over the window.
                loss = loss / __C.gradient_accumulation_steps
            # Back-propagate every micro-batch; with accumulation the original
            # skipped backward() entirely, so no gradients were ever produced.
            loss.backward()

            if (step + 1) % __C.gradient_accumulation_steps != 0:
                continue

            losses.update(loss.item() * __C.gradient_accumulation_steps)
            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                           __C.max_grad_norm)
            # PyTorch expects optimizer.step() before scheduler.step().
            optimizer.step()
            train_scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" %
                (global_step, total_steps, losses.val))

            writer.add_scalar("[Step] Train/loss",
                              scalar_value=losses.val,
                              global_step=global_step)
            writer.add_scalar("[Step] Train/lr",
                              scalar_value=train_scheduler.get_lr()[0],
                              global_step=global_step)

            if global_step % __C.eval_every == 0:
                accuracy = valid(__C,
                                 model=net,
                                 writer=writer,
                                 test_loader=test_loader,
                                 global_step=global_step,
                                 loss_function=loss_function)
                if best_acc < accuracy:
                    torch.save(
                        net.state_dict(),
                        checkpoint_path.format(net=__C.model,
                                               global_step=global_step,
                                               type='best'))
                    best_acc = accuracy
                net.train()

            if global_step % total_steps == 0:
                break
        losses.reset()
        if global_step % total_steps == 0:
            break

    writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")