Ejemplo n.º 1
0
    def _train_epoch(sess, epoch, step, smry_writer):
        """Train over `train_data` for one epoch and return the last global step.

        NOTE(review): relies on enclosing-scope names (train_data, config_data,
        config_model, train_op, mle_loss, global_step, summary_merged, the
        encoder/decoder/labels/learning_rate placeholders, logger, utils,
        data_utils and _eval_epoch) — presumably defined in the surrounding
        training script; verify against the caller.

        Args:
            sess: TensorFlow session used for every `run` call.
            epoch: Current epoch index, forwarded to `_eval_epoch`.
            step: Global step carried in from the previous epoch.
            smry_writer: Summary writer receiving merged summaries.

        Returns:
            The global step value after the final batch of this epoch.
        """
        # Fresh example order every epoch.
        random.shuffle(train_data)
        # Bucket examples by (source len, target len) to reduce padding waste.
        train_iter = data.iterator.pool(
            train_data,
            config_data.batch_size,
            key=lambda x: (len(x[0]), len(x[1])),
            batch_size_fn=utils.batch_size_fn,
            random_shuffler=data.iterator.RandomShuffler())

        for _, train_batch in enumerate(train_iter):
            # Pad and concatenate into (encoder_in, decoder_in, labels) arrays.
            in_arrays = data_utils.seq2seq_pad_concat_convert(train_batch)
            feed_dict = {
                encoder_input: in_arrays[0],
                decoder_input: in_arrays[1],
                labels: in_arrays[2],
                # Step-dependent learning-rate schedule.
                learning_rate: utils.get_lr(step, config_model.lr)
            }
            fetches = {
                'step': global_step,
                'train_op': train_op,
                'smry': summary_merged,
                'loss': mle_loss,
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)

            step, loss = fetches_['step'], fetches_['loss']
            # `step and ...` also skips step 0.
            if step and step % config_data.display_steps == 0:
                logger.info('step: %d, loss: %.4f', step, loss)
                print('step: %d, loss: %.4f' % (step, loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

            if step and step % config_data.eval_steps == 0:
                _eval_epoch(sess, epoch, mode='eval')
        return step
Ejemplo n.º 2
0
def train(model, optimizer, loader, scheduler, criterion, ema, device, writer):
    """Run one training pass of a detection model over `loader`.

    Each batch is fed through a `DetectionTrainWrapper`; only finite losses
    are backpropagated. Scalars are written to `writer` under its own
    `train_step` counter, gradients are clipped, and the optimizer, EMA and
    LR scheduler are stepped per valid batch.

    Returns:
        Tuple of (model, optimizer, scheduler, writer) after the pass.
    """
    model.train()

    progress = tqdm(enumerate(loader), total=len(loader), leave=False)
    for step, batch in progress:
        images, annotation = batch['img'], batch['annotation']
        # Column 4 holds the class label, columns 0..3 the box coordinates.
        gt_labels = annotation[:, :, 4]
        gt_boxes = annotation[:, :, :4]
        n_in_batch = images.shape[0]

        train_wrapper = DetectionTrainWrapper(model, device, criterion)
        total_loss, cls_loss, box_loss = train_wrapper(images, gt_labels, gt_boxes)

        scalars = [t.data.item() for t in (total_loss, cls_loss, box_loss)]
        progress.set_description(
            "all:{0:.2f} | cls:{1:.2f} | box:{2:.2f}".format(*scalars))

        # Skip NaN/inf losses entirely — no backward, no logging, no step.
        if not is_valid_number(total_loss.data.item()):
            continue

        total_loss.backward()

        for tag, value in (
                ('Train/overall_loss', scalars[0]),
                ('Train/class_loss', scalars[1]),
                ('Train/box_loss', scalars[2]),
                ('Train/gradnorm', get_gradnorm(optimizer)),
                ('Train/lr', get_lr(optimizer)),
                ("Train/gpu memory", torch.cuda.memory_allocated(device))):
            writer.add_scalar(tag, value, writer.train_step)

        writer.train_step += 1
        writer.flush()

        clip_grad_norm_(model.parameters(), cfg.CLIP_GRADIENTS_NORM)
        optimizer.step()
        optimizer.zero_grad()

        ema(model, step // n_in_batch)

        scheduler.step()

    return model, optimizer, scheduler, writer
    def _train_epoch(sess, epoch, step, smry_writer):
        """Train until the 'train' dataset iterator raises OutOfRangeError.

        NOTE(review): relies on enclosing-scope names (data_iterator, train_op,
        mle_loss, global_step, summary_merged, learning_rate, tx, utils,
        config_model, config_data and _eval_epoch) — presumably defined in the
        surrounding script; verify against the caller.

        Args:
            sess: TensorFlow session.
            epoch: Current epoch index (printed and forwarded to `_eval_epoch`).
            step: Global step carried in from the previous epoch.
            smry_writer: Summary writer for TensorBoard summaries.

        Returns:
            The global step value after the final batch of this epoch.
        """
        print('Start epoch %d' % epoch)
        # Rewind the 'train' split so iteration restarts from the beginning.
        data_iterator.restart_dataset(sess, 'train')

        fetches = {
            'train_op': train_op,
            'loss': mle_loss,
            'step': global_step,
            'smry': summary_merged
        }

        while True:
            try:
                feed_dict = {
                    data_iterator.handle:
                    data_iterator.get_handle(sess, 'train'),
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                    # Step-dependent learning-rate schedule.
                    learning_rate: utils.get_lr(step, config_model)
                }

                fetches_ = sess.run(fetches, feed_dict)
                step, loss = fetches_['step'], fetches_['loss']

                # Display every display_steps
                display_steps = config_data.display_steps
                if display_steps > 0 and step % display_steps == 0:
                    print(
                        '[%s] step: %d, loss: %.4f' %
                        (strftime("%Y-%m-%d %H:%M:%S", gmtime()), step, loss))
                    smry_writer.add_summary(fetches_['smry'], global_step=step)

                # Eval every eval_steps
                eval_steps = config_data.eval_steps
                if eval_steps > 0 and step % eval_steps == 0 and step > 0:
                    _eval_epoch(sess, epoch, 'eval')

            except tf.errors.OutOfRangeError:
                # Dataset exhausted: the epoch is over.
                break

        return step
Ejemplo n.º 4
0
def train(model, optimizer, loader, scheduler, criterion, ema, device, writer):
    """Run one training pass over `loader`, logging scalars to `writer`.

    Fixes vs. the original:
      * The progress-bar format string used ``"{.2f}"``, which is not a
        format spec (``"{.2f}".format(x)`` attempts attribute lookup ``2f``
        and raises at runtime). It now uses ``"{0:.2f}"`` etc., matching
        the sibling ``train`` implementation in this file.
      * ``batch`` is unpacked before its size is read, and the batch size
        is taken from ``x.shape[0]`` (number of samples). The original read
        ``batch.shape[0]`` first, which crashes for an ``(x, labels)``
        tuple and yields 2 (not the sample count) for a stacked pair.
        NOTE(review): this changes the value fed to the EMA schedule —
        confirm against how `ema` interprets its second argument.

    Returns:
        Tuple of (model, optimizer, scheduler, writer) after the pass.
    """
    model.train()

    pbar = tqdm(enumerate(loader), total=len(loader), leave=False)
    for step, batch in pbar:

        x, labels = batch
        batch_size = x.shape[0]
        cls_output, box_output = model(x)

        loss, cls_loss, box_loss = criterion(cls_output, box_output, labels)
        values = [v.data.item() for v in [loss, cls_loss, box_loss]]

        pbar.set_description(
            "all:{0:.2f} | cls:{1:.2f} | box:{2:.2f}".format(
                values[0], values[1], values[2])
        )

        # Backpropagate and step only when the loss is finite.
        if is_valid_number(loss.data.item()):
            loss.backward()

            writer.add_scalar('Train/overall_loss', values[0], writer.train_step)
            writer.add_scalar('Train/class_loss', values[1], writer.train_step)
            writer.add_scalar('Train/box_loss', values[2], writer.train_step)
            writer.add_scalar('Train/gradnorm', get_gradnorm(optimizer), writer.train_step)
            writer.add_scalar('Train/lr', get_lr(optimizer), writer.train_step)
            writer.add_scalar("Train/gpu memory", torch.cuda.memory_allocated(device), writer.train_step)

            writer.train_step += 1

            clip_grad_norm_(model.parameters(), cfg.CLIP_GRADIENTS_NORM)
            optimizer.step()
            optimizer.zero_grad()

            ema(model, step // batch_size)

            scheduler.step()

    return model, optimizer, scheduler, writer
Ejemplo n.º 5
0
def fit_one_epoch(model_train,
                  model,
                  ssd_loss,
                  loss_history,
                  eval_callback,
                  optimizer,
                  epoch,
                  epoch_step,
                  epoch_step_val,
                  gen,
                  gen_val,
                  Epoch,
                  cuda,
                  fp16,
                  scaler,
                  save_period,
                  save_dir,
                  local_rank=0):
    """Train and validate an SSD model for one epoch, then checkpoint.

    Runs `epoch_step` training batches from `gen` (optionally with AMP when
    `fp16` is true), then `epoch_step_val` validation batches from `gen_val`.
    On rank 0 it drives progress bars, records losses in `loss_history`,
    fires `eval_callback`, and saves weights (periodic, best-by-val-loss,
    and last) under `save_dir`.

    Args:
        model_train: The (possibly wrapped, e.g. DDP) module run forward/backward.
        model: The underlying module whose `state_dict` is checkpointed.
        ssd_loss: Loss object; called as `ssd_loss.forward(targets, out)`.
        loss_history: Recorder with `append_loss(epoch, train, val)` and `val_loss`.
        eval_callback: Object with `on_epoch_end(epoch, model_train)`.
        optimizer: Optimizer stepped each training batch.
        epoch: Zero-based current epoch index.
        epoch_step: Number of training batches to consume.
        epoch_step_val: Number of validation batches to consume.
        gen: Training batch generator yielding (images, targets).
        gen_val: Validation batch generator yielding (images, targets).
        Epoch: Total number of epochs (for display and final save).
        cuda: Whether to move tensors to GPU `local_rank`.
        fp16: Whether to train under `torch.cuda.amp.autocast` with `scaler`.
        scaler: GradScaler used only when `fp16` is true.
        save_period: Save a periodic checkpoint every this many epochs.
        save_dir: Directory for checkpoint files.
        local_rank: Distributed rank; only rank 0 prints/saves.
    """
    total_loss = 0
    val_loss = 0

    # Progress bar and prints only on the main process.
    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step,
                    desc=f'Epoch {epoch + 1}/{Epoch}',
                    postfix=dict,
                    mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        images, targets = batch[0], batch[1]
        # no_grad: device transfer should not be tracked by autograd.
        with torch.no_grad():
            if cuda:
                images = images.cuda(local_rank)
                targets = targets.cuda(local_rank)
        if not fp16:
            #----------------------#
            #   Forward pass
            #----------------------#
            out = model_train(images)
            #----------------------#
            #   Zero the gradients
            #----------------------#
            optimizer.zero_grad()
            #----------------------#
            #   Compute the loss
            #----------------------#
            loss = ssd_loss.forward(targets, out)
            #----------------------#
            #   Backward pass
            #----------------------#
            loss.backward()
            optimizer.step()
        else:
            from torch.cuda.amp import autocast
            with autocast():
                #----------------------#
                #   Forward pass
                #----------------------#
                out = model_train(images)
                #----------------------#
                #   Zero the gradients
                #----------------------#
                optimizer.zero_grad()
                #----------------------#
                #   Compute the loss
                #----------------------#
                loss = ssd_loss.forward(targets, out)

            #----------------------#
            #   Backward pass (loss scaled for mixed precision)
            #----------------------#
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        total_loss += loss.item()

        if local_rank == 0:
            pbar.set_postfix(
                **{
                    'total_loss': total_loss / (iteration + 1),
                    'lr': get_lr(optimizer)
                })
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val,
                    desc=f'Epoch {epoch + 1}/{Epoch}',
                    postfix=dict,
                    mininterval=0.3)

    model_train.eval()
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break
        images, targets = batch[0], batch[1]
        with torch.no_grad():
            if cuda:
                images = images.cuda(local_rank)
                targets = targets.cuda(local_rank)

            out = model_train(images)
            # NOTE(review): zero_grad during no-grad validation has no effect
            # on this loss computation; kept as in the original.
            optimizer.zero_grad()
            loss = ssd_loss.forward(targets, out)
            val_loss += loss.item()

            if local_rank == 0:
                pbar.set_postfix(
                    **{
                        'val_loss': val_loss / (iteration + 1),
                        'lr': get_lr(optimizer)
                    })
                pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        loss_history.append_loss(epoch + 1, total_loss / epoch_step,
                                 val_loss / epoch_step_val)
        eval_callback.on_epoch_end(epoch + 1, model_train)
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        print('Total Loss: %.3f || Val Loss: %.3f ' %
              (total_loss / epoch_step, val_loss / epoch_step_val))

        #-----------------------------------------------#
        #   Save the weights
        #-----------------------------------------------#
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(
                model.state_dict(),
                os.path.join(
                    save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" %
                    (epoch + 1, total_loss / epoch_step,
                     val_loss / epoch_step_val)))

        # Best-so-far checkpoint by validation loss.
        if len(loss_history.val_loss) <= 1 or (
                val_loss / epoch_step_val) <= min(loss_history.val_loss):
            print('Save best model to best_epoch_weights.pth')
            torch.save(model.state_dict(),
                       os.path.join(save_dir, "best_epoch_weights.pth"))

        # Always refresh the rolling "last" checkpoint.
        torch.save(model.state_dict(),
                   os.path.join(save_dir, "last_epoch_weights.pth"))
Ejemplo n.º 6
0
                    mname = os.path.join(config['log_dir'], '{}.pt'.format(monitor.counter + 1))
                    print('saving model', mname)
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'counter': monitor.counter,
                        'epoch': epoch,
                        }, mname)
                    model.train()
            if (monitor.counter + 1) % config['backup_rate'] == 0:
                with torch.no_grad():
                    model.eval()
                    mname = os.path.join(config['log_dir'], 'latest.pt')
                    print('saving model', mname)
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'counter': monitor.counter,
                        'epoch': epoch,
                        }, mname)
                    model.train()

            valid_metric = monitor.step(loss, dict_loss)
            if valid_metric is not None:
                scheduler.step(valid_metric)
                monitor.writer.add_scalar('val/learning_rate', get_lr(optimizer), monitor.counter)
            if monitor.counter >= config['max_iterations']:
                break
def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer,
                  epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda,
                  save_period):
    """Train and validate a YOLO model for one epoch and save a checkpoint.

    Batches arrive from `gen`/`gen_val` as numpy arrays and are converted to
    FloatTensors (moved to GPU when `cuda`). The total loss is the sum of
    `yolo_loss` evaluated on each output scale. Weights are saved to
    `logs/` every `save_period` epochs and on the final epoch.

    Args:
        model_train: The module run forward/backward.
        model: The underlying module whose `state_dict` is saved.
        yolo_loss: Per-scale loss; called as `yolo_loss(l, outputs[l], targets)`.
        loss_history: Recorder with `append_loss(epoch, train, val)`.
        optimizer: Optimizer stepped each training batch.
        epoch: Zero-based current epoch index.
        epoch_step: Number of training batches to consume.
        epoch_step_val: Number of validation batches to consume.
        gen: Training generator yielding (images, targets) numpy batches.
        gen_val: Validation generator yielding (images, targets) numpy batches.
        Epoch: Total number of epochs (for display and final save).
        cuda: Whether to move tensors to GPU.
        save_period: Save a checkpoint every this many epochs.
    """
    loss = 0
    val_loss = 0

    model_train.train()
    print('Start Train')
    with tqdm(total=epoch_step,
              desc=f'Epoch {epoch + 1}/{Epoch}',
              postfix=dict,
              mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_step:
                break

            images, targets = batch[0], batch[1]
            # no_grad: numpy->tensor conversion should not be tracked.
            with torch.no_grad():
                if cuda:
                    images = torch.from_numpy(images).type(
                        torch.FloatTensor).cuda()
                    targets = [
                        torch.from_numpy(ann).type(torch.FloatTensor).cuda()
                        for ann in targets
                    ]
                else:
                    images = torch.from_numpy(images).type(torch.FloatTensor)
                    targets = [
                        torch.from_numpy(ann).type(torch.FloatTensor)
                        for ann in targets
                    ]
            #----------------------#
            #   Zero the gradients
            #----------------------#
            optimizer.zero_grad()
            #----------------------#
            #   Forward pass
            #----------------------#
            outputs = model_train(images)

            loss_value_all = 0
            #----------------------#
            #   Compute the loss (summed over output scales)
            #----------------------#
            for l in range(len(outputs)):
                loss_item = yolo_loss(l, outputs[l], targets)
                loss_value_all += loss_item
            loss_value = loss_value_all

            #----------------------#
            #   Backward pass
            #----------------------#
            loss_value.backward()
            optimizer.step()

            loss += loss_value.item()

            pbar.set_postfix(**{
                'loss': loss / (iteration + 1),
                'lr': get_lr(optimizer)
            })
            pbar.update(1)

    print('Finish Train')

    model_train.eval()
    print('Start Validation')
    with tqdm(total=epoch_step_val,
              desc=f'Epoch {epoch + 1}/{Epoch}',
              postfix=dict,
              mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen_val):
            if iteration >= epoch_step_val:
                break
            images, targets = batch[0], batch[1]
            with torch.no_grad():
                if cuda:
                    images = torch.from_numpy(images).type(
                        torch.FloatTensor).cuda()
                    targets = [
                        torch.from_numpy(ann).type(torch.FloatTensor).cuda()
                        for ann in targets
                    ]
                else:
                    images = torch.from_numpy(images).type(torch.FloatTensor)
                    targets = [
                        torch.from_numpy(ann).type(torch.FloatTensor)
                        for ann in targets
                    ]
                #----------------------#
                #   Zero the gradients
                #   NOTE(review): a no-op for validation under no_grad;
                #   kept as in the original.
                #----------------------#
                optimizer.zero_grad()
                #----------------------#
                #   Forward pass
                #----------------------#
                outputs = model_train(images)

                loss_value_all = 0
                #----------------------#
                #   Compute the loss
                #----------------------#
                for l in range(len(outputs)):
                    loss_item = yolo_loss(l, outputs[l], targets)
                    loss_value_all += loss_item
                loss_value = loss_value_all

            val_loss += loss_value.item()
            pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
            pbar.update(1)

    print('Finish Validation')

    loss_history.append_loss(epoch + 1, loss / epoch_step,
                             val_loss / epoch_step_val)
    print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
    print('Total Loss: %.3f || Val Loss: %.3f ' %
          (loss / epoch_step, val_loss / epoch_step_val))
    if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
        torch.save(
            model.state_dict(), 'logs/ep%03d-loss%.3f-val_loss%.3f.pth' %
            (epoch + 1, loss / epoch_step, val_loss / epoch_step_val))
Ejemplo n.º 8
0
def fit_one_epoch(model_train, model, loss_history, optimizer, epoch,
                  epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda,
                  dice_loss, focal_loss, cls_weights, aux_branch, num_classes):
    """Train and validate a segmentation model for one epoch, then checkpoint.

    Per-batch loss is focal or cross-entropy (per `focal_loss`), optionally
    augmented with Dice loss, and — when `aux_branch` is set — summed over
    the auxiliary and main heads. f_score is tracked for both phases.

    Fix vs. the original: in both the training and validation focal-loss
    branches, `aux_loss` was computed from the main `outputs` instead of
    `aux_outputs`, so the auxiliary head was never supervised under focal
    loss (the CE branches show the intended pattern). Both occurrences now
    use `aux_outputs`.

    Args:
        model_train: The module run forward/backward.
        model: The underlying module whose `state_dict` is saved.
        loss_history: Recorder with `append_loss(train, val)`.
        optimizer: Optimizer stepped each training batch.
        epoch: Zero-based current epoch index.
        epoch_step: Number of training batches to consume.
        epoch_step_val: Number of validation batches to consume.
        gen: Training generator yielding (imgs, pngs, labels) numpy batches.
        gen_val: Validation generator yielding (imgs, pngs, labels) batches.
        Epoch: Total number of epochs (for display).
        cuda: Whether to move tensors to GPU.
        dice_loss: Whether to add Dice loss to the CE/focal loss.
        focal_loss: Whether to use focal loss instead of cross-entropy.
        cls_weights: Per-class weights (numpy array).
        aux_branch: Whether the model returns (aux_outputs, outputs).
        num_classes: Number of segmentation classes.
    """
    total_loss = 0
    total_f_score = 0

    val_loss = 0
    val_f_score = 0

    model_train.train()
    print('Start Train')
    with tqdm(total=epoch_step,
              desc=f'Epoch {epoch + 1}/{Epoch}',
              postfix=dict,
              mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_step:
                break
            imgs, pngs, labels = batch

            # no_grad: numpy->tensor conversion should not be tracked.
            with torch.no_grad():
                imgs = torch.from_numpy(imgs).type(torch.FloatTensor)
                pngs = torch.from_numpy(pngs).long()
                labels = torch.from_numpy(labels).type(torch.FloatTensor)
                weights = torch.from_numpy(cls_weights)
                if cuda:
                    imgs = imgs.cuda()
                    pngs = pngs.cuda()
                    labels = labels.cuda()
                    weights = weights.cuda()

            optimizer.zero_grad()
            if aux_branch:
                aux_outputs, outputs = model_train(imgs)
                if focal_loss:
                    # Fixed: auxiliary loss must use the auxiliary head.
                    aux_loss = Focal_Loss(aux_outputs,
                                          pngs,
                                          weights,
                                          num_classes=num_classes)
                    main_loss = Focal_Loss(outputs,
                                           pngs,
                                           weights,
                                           num_classes=num_classes)
                else:
                    aux_loss = CE_Loss(aux_outputs,
                                       pngs,
                                       weights,
                                       num_classes=num_classes)
                    main_loss = CE_Loss(outputs,
                                        pngs,
                                        weights,
                                        num_classes=num_classes)
                loss = aux_loss + main_loss
                if dice_loss:
                    aux_dice = Dice_loss(aux_outputs, labels)
                    main_dice = Dice_loss(outputs, labels)
                    loss = loss + aux_dice + main_dice
            else:
                outputs = model_train(imgs)
                if focal_loss:
                    loss = Focal_Loss(outputs,
                                      pngs,
                                      weights,
                                      num_classes=num_classes)
                else:
                    loss = CE_Loss(outputs,
                                   pngs,
                                   weights,
                                   num_classes=num_classes)

                if dice_loss:
                    main_dice = Dice_loss(outputs, labels)
                    loss = loss + main_dice

            with torch.no_grad():
                #-------------------------------#
                #   Compute f_score
                #-------------------------------#
                _f_score = f_score(outputs, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_f_score += _f_score.item()

            pbar.set_postfix(
                **{
                    'total_loss': total_loss / (iteration + 1),
                    'f_score': total_f_score / (iteration + 1),
                    'lr': get_lr(optimizer)
                })
            pbar.update(1)

    print('Finish Train')

    model_train.eval()
    print('Start Validation')
    with tqdm(total=epoch_step_val,
              desc=f'Epoch {epoch + 1}/{Epoch}',
              postfix=dict,
              mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen_val):
            if iteration >= epoch_step_val:
                break
            imgs, pngs, labels = batch
            with torch.no_grad():
                imgs = torch.from_numpy(imgs).type(torch.FloatTensor)
                pngs = torch.from_numpy(pngs).long()
                labels = torch.from_numpy(labels).type(torch.FloatTensor)
                weights = torch.from_numpy(cls_weights)
                if cuda:
                    imgs = imgs.cuda()
                    pngs = pngs.cuda()
                    labels = labels.cuda()
                    weights = weights.cuda()
                #-------------------------------#
                #   Check whether the auxiliary branch is used
                #-------------------------------#
                if aux_branch:
                    aux_outputs, outputs = model_train(imgs)
                    if focal_loss:
                        # Fixed: auxiliary loss must use the auxiliary head.
                        aux_loss = Focal_Loss(aux_outputs,
                                              pngs,
                                              weights,
                                              num_classes=num_classes)
                        main_loss = Focal_Loss(outputs,
                                               pngs,
                                               weights,
                                               num_classes=num_classes)
                    else:
                        aux_loss = CE_Loss(aux_outputs,
                                           pngs,
                                           weights,
                                           num_classes=num_classes)
                        main_loss = CE_Loss(outputs,
                                            pngs,
                                            weights,
                                            num_classes=num_classes)
                    loss = aux_loss + main_loss
                    if dice_loss:
                        aux_dice = Dice_loss(aux_outputs, labels)
                        main_dice = Dice_loss(outputs, labels)
                        loss = loss + aux_dice + main_dice
                else:
                    outputs = model_train(imgs)
                    if focal_loss:
                        loss = Focal_Loss(outputs,
                                          pngs,
                                          weights,
                                          num_classes=num_classes)
                    else:
                        loss = CE_Loss(outputs,
                                       pngs,
                                       weights,
                                       num_classes=num_classes)

                    if dice_loss:
                        main_dice = Dice_loss(outputs, labels)
                        loss = loss + main_dice
                #-------------------------------#
                #   Compute f_score
                #-------------------------------#
                _f_score = f_score(outputs, labels)

                val_loss += loss.item()
                val_f_score += _f_score.item()

            pbar.set_postfix(
                **{
                    'total_loss': val_loss / (iteration + 1),
                    'f_score': val_f_score / (iteration + 1),
                    'lr': get_lr(optimizer)
                })
            pbar.update(1)

    loss_history.append_loss(total_loss / epoch_step,
                             val_loss / epoch_step_val)
    print('Finish Validation')
    print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
    print('Total Loss: %.3f || Val Loss: %.3f ' %
          (total_loss / epoch_step, val_loss / epoch_step_val))
    torch.save(
        model.state_dict(), 'logs/ep%03d-loss%.3f-val_loss%.3f.pth' %
        ((epoch + 1), total_loss / epoch_step, val_loss / epoch_step_val))
Ejemplo n.º 9
0
def main_teacher(args):
    """Train a teacher model with optional augmentation/cutout, checkpointing.

    Builds transforms and dataloaders from `args`, trains with SGD for
    `args.n_epochs` epochs (resuming from `args.resume` if present), logs
    learning rate and losses to TensorBoard, and saves a checkpoint each
    epoch (flagged best when validation loss improves).

    Fix vs. the original: `time.asctime` was interpolated without calling
    it, printing the function object instead of the current time; it is now
    invoked as `time.asctime()`.

    Args:
        args: Namespace with training configuration (name, dataset, gpu,
            teacher_model, augmentation/cutout/mixup/cutmix options, lr,
            decay, batch_size, n_epochs, resume, ...).
    """

    print(("Process {}, running on {}: starting {}").format(
        os.getpid(), os.name, time.asctime()))

    print("Training with Augmentation: ", args.augmentation)
    print("Training with Cutout: ", args.cutout)
    print("Training with Mixup: ", args.mixup)
    print("Training with CutMix: ", args.cutmix)

    # Unique run directory from a second-resolution timestamp.
    process_num = round(time.time())
    dir_name = args.name + '_' + str(process_num) + str(args.dataset)
    tb_path = "distillation_experiments/logs/%s/" % (dir_name)
    pprint(args.__dict__)

    print(dir_name)

    writer = SummaryWriter(tb_path)

    use_gpu = args.gpu
    if not torch.cuda.is_available():
        use_gpu = False

    # Load Models
    model = model_fetch.fetch_teacher(args.teacher_model)

    if use_gpu:
        cudnn.benchmark = True
        model = model.cuda()

    # CIFAR-style normalization used for both train and validation.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])
    train_transform = val_transform = transform

    if args.augmentation:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(
            ),  # randomly flip image horizontally
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.247, 0.243, 0.261))
        ])

    if args.cutout:
        train_transform.transforms.append(
            cutout.Cutout(n_holes=args.n_holes, length=args.length_holes))

    train_loader = dataloader.fetch_dataloader("train", train_transform,
                                               args.dataset, args.batch_size)
    test_loader = dataloader.fetch_dataloader("test", val_transform,
                                              args.dataset, args.batch_size)

    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.SGD(params,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.decay)
    loss_fn = utils.loss_fn
    acc_fn = utils.accuracy

    # Resume (epoch index and best loss) if a checkpoint exists.
    start_epoch, best_loss = utils.load_checkpoint(model, args.resume)
    epoch = start_epoch
    while epoch <= int(args.n_epochs):
        print("=" * 50)
        utils.adjust_learning_rate(args.lr, optimizer, epoch, args.lr_decay)
        print(("Epoch {} Training Starting").format(epoch))
        print("Learning Rate : ", utils.get_lr(optimizer))

        train_loss = train.train(model, optimizer, loss_fn, acc_fn,
                                 train_loader, use_gpu, epoch, writer,
                                 args.mixup, args.alpha, args.cutmix,
                                 args.cutmix_prob, args.cutmix_beta)

        val_loss = train.validate(model, loss_fn, acc_fn, test_loader, use_gpu,
                                  epoch, writer)

        print("-" * 50)
        print(("Epoch {}, Training-Loss: {}, Validation-Loss: {}").format(
            epoch, train_loss, val_loss))
        print("=" * 50)

        curr_state = {
            "epoch": epoch,
            "best_loss": min(best_loss, val_loss),
            "model": model.state_dict()
        }

        # # Use only if model to be saved at each epoch
        # filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar'

        utils.save_checkpoint(
            state=curr_state,
            is_best=bool(val_loss < best_loss),
            dir_name=dir_name,
            # filename=filename
        )

        if val_loss < best_loss:
            best_loss = val_loss
        epoch += 1
        # NOTE(review): logged after the increment, so the LR is attributed
        # to the upcoming epoch — kept as in the original.
        writer.add_scalar('data/learning_rate', utils.get_lr(optimizer), epoch)
Ejemplo n.º 10
0
def main_kd(args):
    """Train a student model via knowledge distillation from a teacher.

    Loads the (frozen) teacher and the student, restores checkpoints, then
    runs the KD training loop for ``args.n_epochs`` epochs, logging to
    TensorBoard and checkpointing the best student by validation loss.

    Args:
        args: parsed argument namespace (temperature, gamma, model names,
            paths, optimizer hyper-parameters, etc.).
    """

    print(("Process {}, running on {}: starting {}").format(
        os.getpid(), os.name, time.asctime))

    print("Student Training Underway!")
    print("Temperature: ", args.temperature)
    print("Relative Loss Weights: ", args.gamma)
    # Timestamp used to make the experiment directory name unique.
    process_num = round(time.time())

    model_name = args.name + '_temp' + \
        str(args.temperature) + '_gamma' + str(args.gamma)
    dir_name = model_name + '_' + str(process_num)

    tb_path = "distillation_experiments/logs/%s/" % (dir_name)

    writer = SummaryWriter(tb_path)

    print("Arguments for model: ", model_name)
    pprint(args.__dict__)

    # Fall back to CPU when CUDA is unavailable, regardless of the flag.
    use_gpu = args.gpu
    if not torch.cuda.is_available():
        use_gpu = False

    # Load Models
    teacher_model = model_fetch.fetch_teacher(args.teacher_model)
    student_model = model_fetch.fetch_student(args.student_model)

    if use_gpu:
        cudnn.benchmark = True
        teacher_model = teacher_model.cuda()
        student_model = student_model.cuda()

    # CIFAR-style normalization; identical transform for train and val.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])
    train_transform = val_transform = transform
    # NOTE(review): a dead `if False:` block that built augmentation/cutout
    # train transforms was removed here — it could never execute, so behavior
    # is unchanged. Re-introduce it behind `args.augmentation` / `args.cutout`
    # if augmentation is actually desired for KD training.

    train_loader = dataloader.fetch_dataloader("train", train_transform,
                                               args.batch_size)
    test_loader = dataloader.fetch_dataloader("test", val_transform,
                                              args.batch_size)

    # Only the student is optimized; the teacher acts as a fixed target.
    params = [p for p in student_model.parameters() if p.requires_grad]

    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.99)

    loss_fn = utils.kd_loss_fn
    simple_loss_fn = utils.loss_fn

    # load_checkpoint restores weights in place; the teacher's returned
    # epoch/loss values are not used afterwards.
    teacher_epoch, teacher_loss = utils.load_checkpoint(
        teacher_model, args.teacher_path)
    start_epoch, best_loss = utils.load_checkpoint(student_model, args.resume)

    print("Models Loaded!")
    print(student_model)

    epoch = start_epoch
    while epoch <= int(args.n_epochs):
        print("=" * 50)
        utils.adjust_learning_rate(args.lr, optimizer, epoch, args.lr_decay)
        print(("Epoch {} Training Starting").format(epoch))
        print("Learning Rate : ", utils.get_lr(optimizer))

        train_loss = train_kd.train(student_model, teacher_model, optimizer,
                                    loss_fn, train_loader, use_gpu, epoch,
                                    writer, args.temperature, args.gamma)
        # Validation scores the student alone with the plain (non-KD) loss.
        val_loss = train_kd.validate(student_model, simple_loss_fn,
                                     test_loader, use_gpu, epoch, writer)

        print("-" * 50)
        print(("Epoch {}, Training-Loss: {}, Validation-Loss: {}").format(
            epoch, train_loss, val_loss))
        print("=" * 50)

        curr_state = {
            "epoch": epoch,
            "best_loss": min(best_loss, val_loss),
            "model": student_model.state_dict()
        }

        # # Use only if model to be saved at each epoch
        # filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar'

        utils.save_checkpoint(
            state=curr_state,
            is_best=bool(val_loss < best_loss),
            dir_name=dir_name,
            # filename=filename
        )

        if val_loss < best_loss:
            best_loss = val_loss
        epoch += 1
        writer.add_scalar('data/learning_rate', utils.get_lr(optimizer), epoch)
    def train(self):
        """Train the adversarial autoencoder (Encoder/Decoder vs Discriminator).

        Runs epochs ``self.epoch``..``self.num_epoch``, logging per-epoch loss
        averages to visdom windows, saving sample reconstructions every 100
        steps, and checkpointing all three networks after every epoch.
        """
        # One optimizer over the joint encoder+decoder parameter chain.
        optimizer_ae = Adam(chain(self.Encoder.parameters(),
                                  self.Decoder.parameters()),
                            self.lr,
                            betas=(self.b1, self.b2),
                            weight_decay=self.weight_decay)
        optimizer_discriminator = Adam(self.Disciminator.parameters(),
                                       self.lr,
                                       betas=(self.b1, self.b2),
                                       weight_decay=self.weight_decay)
        # LR decay is attached to the AE optimizer only; the discriminator
        # keeps its initial learning rate throughout.
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer_ae,
            LambdaLR(self.num_epoch, self.epoch, self.decay_epoch).step)
        total_step = len(self.data_loader)

        perceptual_criterion = PerceptualLoss().to(self.device)
        content_criterion = nn.L1Loss().to(self.device)
        adversarial_criterion = nn.BCELoss().to(self.device)

        self.Encoder.train()
        self.Decoder.train()
        # Running averages, reset at the start of every epoch.
        content_losses = AverageMeter()
        generator_losses = AverageMeter()
        perceptual_losses = AverageMeter()
        discriminator_losses = AverageMeter()
        ae_losses = AverageMeter()

        lr_window = create_vis_plot('Epoch', 'Learning rate', 'Learning rate')
        loss_window = create_vis_plot('Epoch', 'Loss', 'Total Loss')
        generator_loss_window = create_vis_plot('Epoch', 'Loss',
                                                'Generator Loss')
        discriminator_loss_window = create_vis_plot('Epoch', 'Loss',
                                                    'Discriminator Loss')
        content_loss_window = create_vis_plot('Epoch', 'Loss', 'Content Loss')
        perceptual_loss_window = create_vis_plot('Epoch', 'Loss',
                                                 'Perceptual Loss')

        if not os.path.exists(self.sample_dir):
            os.makedirs(self.sample_dir)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        for epoch in range(self.epoch, self.num_epoch):
            content_losses.reset()
            perceptual_losses.reset()
            generator_losses.reset()
            ae_losses.reset()
            discriminator_losses.reset()
            for step, images in enumerate(self.data_loader):
                images = images.to(self.device)

                # BCE targets: 1 for real images, 0 for reconstructions.
                real_labels = torch.ones((images.size(0), 1)).to(self.device)
                fake_labels = torch.zeros((images.size(0), 1)).to(self.device)

                encoded_image = self.Encoder(images)

                # NOTE(review): the compressed bitstream is computed but never
                # used (the decompression path below is commented out), so this
                # adds per-step CPU overhead only. Presumably kept to exercise
                # compressibility of the latent — confirm before removing.
                binary_decoded_image = paq.compress(
                    encoded_image.cpu().detach().numpy().tobytes())
                # encoded_image = paq.decompress(binary_decoded_image)
                #
                # encoded_image = torch.from_numpy(np.frombuffer(encoded_image, dtype=np.float32)
                #                                  .reshape(-1, self.storing_channels, self.image_size // 8,
                #                                           self.image_size // 8)).to(self.device)

                decoded_image = self.Decoder(encoded_image)

                # Autoencoder objective: L1 content + perceptual + adversarial.
                content_loss = content_criterion(images, decoded_image)
                perceptual_loss = perceptual_criterion(images, decoded_image)
                generator_loss = adversarial_criterion(
                    self.Disciminator(decoded_image), real_labels)
                # generator_loss = -self.Disciminator(decoded_image).mean()

                ae_loss = content_loss * self.content_loss_factor + perceptual_loss * self.perceptual_loss_factor + \
                          generator_loss * self.generator_loss_factor

                content_losses.update(content_loss.item())
                perceptual_losses.update(perceptual_loss.item())
                generator_losses.update(generator_loss.item())
                ae_losses.update(ae_loss.item())

                optimizer_ae.zero_grad()
                # retain_graph=True keeps the graph alive: the discriminator
                # losses below backprop through decoded_image again.
                ae_loss.backward(retain_graph=True)
                optimizer_ae.step()

                # Discriminator step: real vs reconstructed, plus a penalty on
                # a real/fake interpolate. NOTE(review): despite the name, no
                # gradient is taken w.r.t. the interpolate here, so this is
                # not a standard WGAN-GP gradient penalty — confirm intent.
                interpolated_image = self.eta * images + (
                    1 - self.eta) * decoded_image
                gravity_penalty = self.Disciminator(interpolated_image).mean()
                real_loss = adversarial_criterion(self.Disciminator(images),
                                                  real_labels)
                fake_loss = adversarial_criterion(
                    self.Disciminator(decoded_image), fake_labels)
                discriminator_loss = (real_loss + fake_loss) * self.discriminator_loss_factor / 2 +\
                                     gravity_penalty * self.penalty_loss_factor

                # discriminator_loss = self.Disciminator(decoded_image).mean() - self.Disciminator(images).mean() + \
                #                      gravity_penalty * self.penalty_loss_factor

                optimizer_discriminator.zero_grad()
                discriminator_loss.backward(retain_graph=True)
                optimizer_discriminator.step()
                discriminator_losses.update(discriminator_loss.item())

                if step % 100 == 0:
                    print(
                        f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] [Learning rate {get_lr(optimizer_ae)}] "
                        f"[Content {content_loss:.4f}] [Perceptual {perceptual_loss:.4f}] [Gan {generator_loss:.4f}]"
                        f"[Discriminator {discriminator_loss:.4f}]")

                    # Stack original over reconstruction for side-by-side view.
                    save_image(
                        torch.cat([images, decoded_image], dim=2),
                        os.path.join(self.sample_dir,
                                     f"Sample-epoch-{epoch}-step-{step}.png"))

            # Push per-epoch averages to the visdom windows.
            update_vis_plot(epoch, ae_losses.avg, loss_window, 'append')
            update_vis_plot(epoch, generator_losses.avg, generator_loss_window,
                            'append')
            update_vis_plot(epoch, discriminator_losses.avg,
                            discriminator_loss_window, 'append')
            update_vis_plot(epoch, content_losses.avg, content_loss_window,
                            'append')
            update_vis_plot(epoch, perceptual_losses.avg,
                            perceptual_loss_window, 'append')
            update_vis_plot(epoch, get_lr(optimizer_ae), lr_window, 'append')

            lr_scheduler.step()

            torch.save(
                self.Encoder.state_dict(),
                os.path.join(self.checkpoint_dir, f"Encoder-{epoch}.pth"))
            torch.save(
                self.Decoder.state_dict(),
                os.path.join(self.checkpoint_dir, f"Decoder-{epoch}.pth"))
            torch.save(
                self.Disciminator.state_dict(),
                os.path.join(self.checkpoint_dir,
                             f"Discriminator-{epoch}.pth"))
Ejemplo n.º 12
0
def fit_one_epoch(model_train, model, focal_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda):
    """Run one training epoch followed by one validation epoch.

    ``model_train`` is the network used for forward/backward passes (possibly
    a parallel wrapper); ``model`` is the underlying module whose weights are
    saved. Average train/val losses are appended to ``loss_history`` and a
    checkpoint is written under ``logs/``.
    """
    loss        = 0
    val_loss    = 0

    model_train.train()
    print('Start Train')
    with tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_step:
                break

            images, targets = batch[0], batch[1]
            with torch.no_grad():
                # numpy -> tensor conversion only; no gradients needed here.
                if cuda:
                    images  = torch.from_numpy(images).type(torch.FloatTensor).cuda()
                    targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
                else:
                    images  = torch.from_numpy(images).type(torch.FloatTensor)
                    targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
            #----------------------#
            #   Zero the gradients
            #----------------------#
            optimizer.zero_grad()
            #-------------------#
            #   Forward pass
            #-------------------#
            _, regression, classification, anchors = model_train(images)
            #-------------------#
            #   Compute the loss
            #-------------------#
            loss_value, _, _ = focal_loss(classification, regression, anchors, targets, cuda=cuda)

            loss_value.backward()
            # NOTE(review): 1e-2 is a very aggressive clip threshold; kept
            # as-is since it may be intentional for this detector.
            torch.nn.utils.clip_grad_norm_(model_train.parameters(), 1e-2)
            optimizer.step()

            loss += loss_value.item()

            pbar.set_postfix(**{'loss'  : loss / (iteration + 1),
                                'lr'    : get_lr(optimizer)})
            pbar.update(1)

    print('Finish Train')

    model_train.eval()
    print('Start Validation')
    with tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen_val):
            if iteration >= epoch_step_val:
                break
            images, targets = batch[0], batch[1]
            with torch.no_grad():
                if cuda:
                    images  = torch.from_numpy(images).type(torch.FloatTensor).cuda()
                    targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
                else:
                    images  = torch.from_numpy(images).type(torch.FloatTensor)
                    targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
                # (A redundant optimizer.zero_grad() was removed here: no
                # backward pass ever happens during validation.)
                #-------------------#
                #   Forward pass
                #-------------------#
                _, regression, classification, anchors = model_train(images)
                #-------------------#
                #   Compute the loss
                #-------------------#
                loss_value, _, _ = focal_loss(classification, regression, anchors, targets, cuda = cuda)

            val_loss += loss_value.item()
            pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
            pbar.update(1)

    print('Finish Validation')

    loss_history.append_loss(loss / epoch_step, val_loss / epoch_step_val)
    print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
    print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
    torch.save(model.state_dict(), 'logs/ep%03d-loss%.3f-val_loss%.3f.pth' % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val))
Ejemplo n.º 13
0
def main(args):
    """Train the bleeds encoder/decoder classifier end to end.

    Builds the train/val/test datasets and loaders, a CNN encoder plus an
    LSTM decoder or aggregator head, then runs the train/validate loop for
    ``args.n_epochs`` epochs, checkpointing the best model by validation
    loss and logging to TensorBoard.
    """
    print("Process %s, running on %s: starting (%s)" %
          (os.getpid(), os.name, time.asctime()))
    # Timestamp makes the experiment directory name unique per run.
    process_num = round(time.time())
    dir_name = args.name + '_' + str(process_num)
    tb_path = "bleeds_experiments/logs/%s/" % (dir_name)

    writer = SummaryWriter(tb_path)

    # Fall back to CPU when CUDA is unavailable, regardless of the flag.
    use_gpu = args.gpu
    if not torch.cuda.is_available():
        use_gpu = False

    # Train transform: augmentations are currently disabled (commented out);
    # only center-crop + resize + normalize are applied.
    transform = transforms.Compose([
        # transforms.RandomCrop(32, padding=4),
        transforms.CenterCrop(900),
        transforms.Resize((224, 224)),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomRotation(30),
        # transforms.RandomPerspective(),
        # transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # NOTE(review): val/test use a 500px center crop vs 900px for training —
    # presumably intentional, but worth confirming.
    val_transform = transforms.Compose([
        transforms.CenterCrop(500),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_dataset = BleedsDataset(transform=transform,
                                  mode="train",
                                  dataset_path=_DATASET_PATH,
                                  batch_size=args.batch_size,
                                  upsample=args.upsample)
    val_dataset = BleedsDataset(transform=val_transform,
                                mode="val",
                                dataset_path=_DATASET_PATH)
    test_dataset = BleedsDataset(transform=val_transform,
                                 mode="test",
                                 dataset_path=_DATASET_PATH)

    train_loader = get_loader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    val_loader = get_loader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=True)
    test_loader = get_loader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True)

    print("Loaded datasets now loading models")

    encoder = EncoderCNN(pretrained=args.pretrained,
                         base_model=args.base_model)

    # Sequence model (LSTM) or simple aggregation head over frame features.
    if args.seq_model:
        decoder = DecoderLSTM()
    else:
        decoder = Aggregator(max_aggregate=args.max_aggregate)

    if use_gpu:
        cudnn.benchmark = True
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    encoder_trainables = [p for p in encoder.parameters() if p.requires_grad]
    decoder_trainables = [p for p in decoder.parameters() if p.requires_grad]

    params = encoder_trainables + decoder_trainables
    optimizer = torch.optim.SGD(params=params, lr=args.lr, momentum=0.9)
    # optimizer = torch.optim.Adam(params=params, lr=args.lr, betas=(0.9, 0.999), eps=1e-08)
    if args.cyclic_lr:
        scheduler = cyclicLR.CyclicCosAnnealingLR(optimizer,
                                                  milestones=[10, 20],
                                                  eta_min=1e-7)

    loss_fn = torch.nn.BCELoss()
    metrics_fn = utils.find_metrics

    start_epoch, best_loss = utils.load_checkpoint(encoder, decoder,
                                                   args.resume)
    epoch = start_epoch

    while epoch <= int(args.n_epochs):
        print("=" * 50)
        # NOTE(review): the scheduler is stepped before training each epoch,
        # mirroring the manual LR-decay call in the other branch.
        if args.cyclic_lr:
            scheduler.step()
        else:
            utils.adjust_learning_rate(args.lr, optimizer, epoch,
                                       args.lr_decay)

        print("Epoch %d Training Starting" % epoch)
        print("Learning Rate : ", utils.get_lr(optimizer))

        print("\n", "-" * 10, "Training", "-" * 10, "\n")
        train_loss = train.train(train_loader, encoder, decoder, optimizer,
                                 loss_fn, metrics_fn, epoch, writer, use_gpu)

        print("\n", "-" * 10, "Validation", "-" * 10, "\n")
        val_loss = train.validate(val_loader,
                                  encoder,
                                  decoder,
                                  loss_fn,
                                  metrics_fn,
                                  epoch,
                                  writer,
                                  use_gpu,
                                  ver="validation")

        print("-" * 50)
        print("Training Loss: ", float(train_loss))
        print("Validation Loss: ", float(val_loss))

        # Periodically score the held-out test split as well.
        if epoch % args.test_epoch == 0:
            # NOTE(review): ver="val" here vs "validation" above — confirm
            # which tag train.validate expects for test-set logging.
            test_loss = train.validate(test_loader,
                                       encoder,
                                       decoder,
                                       loss_fn,
                                       metrics_fn,
                                       epoch,
                                       writer,
                                       use_gpu,
                                       ver="val")

            print("Test Set Loss: ", float(test_loss))
        print("=" * 50)

        # (Removed a spurious `curr_state = state = {...}` alias: `state`
        # was never used.)
        curr_state = {
            "epoch": epoch,
            "best_loss": min(best_loss, val_loss),
            "encoder": encoder.state_dict(),
            "decoder": decoder.state_dict()
        }

        # filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar'

        utils.save_checkpoint(
            state=curr_state,
            is_best=bool(val_loss < best_loss),
            dir_name=dir_name,
            # filename=filename
        )
        if val_loss < best_loss:
            best_loss = val_loss

        epoch += 1
        writer.add_scalar('data/learning_rate', utils.get_lr(optimizer), epoch)
        print(utils.get_lr(optimizer))
Ejemplo n.º 14
0
def _centernet_losses(model_train, batch, backbone):
    """Forward pass + CenterNet loss for one already-on-device batch.

    Returns ``(loss, c_loss_sum, r_loss_sum, n_outputs)`` where the sums are
    loss tensors and ``n_outputs`` is the number of prediction heads to
    average over (1 for the resnet50 backbone, one per stack otherwise).
    """
    batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks = batch
    if backbone == "resnet50":
        hm, wh, offset = model_train(batch_images)
        c_loss = focal_loss(hm, batch_hms)
        wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
        off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
        return c_loss + wh_loss + off_loss, c_loss, wh_loss + off_loss, 1

    # Hourglass-style backbone: one loss per intermediate output.
    outputs = model_train(batch_images)
    loss = 0
    c_loss_all = 0
    r_loss_all = 0
    for output in outputs:
        # Heatmap logits are stored raw; apply sigmoid before the focal loss.
        hm = output["hm"].sigmoid()
        wh, offset = output["wh"], output["reg"]
        c_loss = focal_loss(hm, batch_hms)
        wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
        off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
        loss += c_loss + wh_loss + off_loss
        c_loss_all += c_loss
        r_loss_all += wh_loss + off_loss
    return loss, c_loss_all, r_loss_all, len(outputs)


def fit_one_epoch(model_train,
                  model,
                  loss_history,
                  eval_callback,
                  optimizer,
                  epoch,
                  epoch_step,
                  epoch_step_val,
                  gen,
                  gen_val,
                  Epoch,
                  cuda,
                  fp16,
                  scaler,
                  backbone,
                  save_period,
                  save_dir,
                  local_rank=0):
    """Train and validate a CenterNet model for one epoch.

    Supports fp32 and AMP (``fp16``) training, optional DDP (progress bars,
    logging, and checkpointing only on ``local_rank == 0``). Saves periodic,
    best, and last weights under ``save_dir``.

    The previous four near-identical copies of the forward/loss code
    (fp16/non-fp16 x train/val) are factored into ``_centernet_losses``;
    the per-iteration ``autocast`` import is hoisted out of the loop.
    """
    total_r_loss = 0
    total_c_loss = 0
    total_loss = 0
    val_loss = 0

    if fp16:
        # Hoisted: the original re-imported this on every training step.
        from torch.cuda.amp import autocast

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step,
                    desc=f'Epoch {epoch + 1}/{Epoch}',
                    postfix=dict,
                    mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        with torch.no_grad():
            if cuda:
                batch = [ann.cuda(local_rank) for ann in batch]

        # Zero the gradients before each step.
        optimizer.zero_grad()
        if not fp16:
            loss, c_loss_sum, r_loss_sum, n_out = _centernet_losses(
                model_train, batch, backbone)
            total_loss += loss.item() / n_out
            total_c_loss += c_loss_sum.item() / n_out
            total_r_loss += r_loss_sum.item() / n_out
            loss.backward()
            optimizer.step()
        else:
            with autocast():
                loss, c_loss_sum, r_loss_sum, n_out = _centernet_losses(
                    model_train, batch, backbone)
                total_loss += loss.item() / n_out
                total_c_loss += c_loss_sum.item() / n_out
                total_r_loss += r_loss_sum.item() / n_out

            # Backpropagate through the GradScaler for mixed precision.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        if local_rank == 0:
            pbar.set_postfix(
                **{
                    'total_r_loss': total_r_loss / (iteration + 1),
                    'total_c_loss': total_c_loss / (iteration + 1),
                    'lr': get_lr(optimizer)
                })
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val,
                    desc=f'Epoch {epoch + 1}/{Epoch}',
                    postfix=dict,
                    mininterval=0.3)

    model_train.eval()
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break

        with torch.no_grad():
            if cuda:
                batch = [ann.cuda(local_rank) for ann in batch]
            loss, _, _, n_out = _centernet_losses(model_train, batch, backbone)
            val_loss += loss.item() / n_out

            if local_rank == 0:
                pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
                pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        loss_history.append_loss(epoch + 1, total_loss / epoch_step,
                                 val_loss / epoch_step_val)
        eval_callback.on_epoch_end(epoch + 1, model_train)
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        print('Total Loss: %.3f || Val Loss: %.3f ' %
              (total_loss / epoch_step, val_loss / epoch_step_val))

        #-----------------------------------------------#
        #   Save weights
        #-----------------------------------------------#
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(
                model.state_dict(),
                os.path.join(
                    save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth' %
                    (epoch + 1, total_loss / epoch_step,
                     val_loss / epoch_step_val)))

        # Track the best model by validation loss across the run.
        if len(loss_history.val_loss) <= 1 or (
                val_loss / epoch_step_val) <= min(loss_history.val_loss):
            print('Save best model to best_epoch_weights.pth')
            torch.save(model.state_dict(),
                       os.path.join(save_dir, "best_epoch_weights.pth"))

        torch.save(model.state_dict(),
                   os.path.join(save_dir, "last_epoch_weights.pth"))
Ejemplo n.º 15
0
def main():
    """Entry point: parse args, build the siamese video model, loaders,
    losses and optimizer, then run the train/validate loop while logging
    curves to TensorBoard via the module-level ``WRITER``.

    Side effects: rebinds the module globals ``args``, ``devices``,
    ``DEVICES``, ``WRITER`` and ``description``; writes checkpoints through
    ``save_checkpoint``.
    """
    print(torch.cuda.device_count())
    global args
    global devices
    global WRITER
    args = parser.parse_args()
    global description
    description = 'bt_%d_seg_%d_%s' % (args.batch_size * ACCUMU_STEPS,
                                       args.num_segments, "finetune_from_vcdb")
    log_name = r'/home/sjhu/projects/compressed_video_compare/imqfusion/log/%s' % description
    WRITER = SummaryWriter(log_name)
    print('Training arguments:')
    for k, v in vars(args).items():
        print('\t{}: {}'.format(k, v))

    model = Model(2,
                  args.num_segments,
                  args.representation,
                  base_model=args.arch)

    # Optionally resume training from the last saved checkpoint.
    if CONTINUE_FROM_LAST:
        checkpoint = torch.load(LAST_SAVE_PATH)
        print("model epoch {} lowest loss {}".format(checkpoint['epoch'],
                                                     checkpoint['loss_min']))
        # Strip the leading 'module.' prefix left by DataParallel so the
        # weights load into the bare (un-wrapped) model.
        base_dict = {
            '.'.join(k.split('.')[1:]): v
            for k, v in list(checkpoint['state_dict'].items())
        }
        loss_min = checkpoint['loss_min']
        model.load_state_dict(base_dict)
        start_epochs = checkpoint['epoch']
    else:
        loss_min = 10000
        start_epochs = 0

    devices = [torch.device("cuda:%d" % device) for device in args.gpus]
    global DEVICES
    DEVICES = devices

    # Weighted sampling to compensate for the imbalance between positive
    # and negative samples.
    train_dataset = CoviarDataSet(
        args.data_root,
        video_list=args.train_list,
        num_segments=args.num_segments,
        is_train=True,
    )
    # NOTE(review): assumes _labels_list supports elementwise `==` and
    # fancy indexing (i.e. is array/tensor-like) — confirm in CoviarDataSet.
    target = train_dataset._labels_list
    class_sample_count = torch.tensor([(target == t).sum()
                                       for t in np.unique(target)])
    weight = 1. / class_sample_count.float()
    samples_weights = weight[target]
    train_sampler = WeightedRandomSampler(samples_weights, len(train_dataset),
                                          True)
    # Reuse the dataset built above instead of constructing a second,
    # identical CoviarDataSet: the original computed sampler weights from
    # one instance and iterated another.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(CoviarDataSet(
        args.data_root,
        video_list=args.test_list,
        num_segments=args.num_segments,
        is_train=False,
    ),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus)
    model = model.to(devices[0])
    cudnn.benchmark = True

    # Per-parameter optimizer groups: the fc / fusion / mvnet sub-modules
    # train at 10x the base lr; biases get no weight decay.
    boosted_prefixes = ('module.fc', 'module.fusion', 'module.mvnet')
    params = []
    for key, value in dict(model.named_parameters()).items():
        decay_mult = 0.0 if 'bias' in key else 1.0
        lr_mult = 10 if any(prefix in key for prefix in boosted_prefixes) else 1
        params.append({
            'params': [value],
            'lr': args.lr * lr_mult,
            'decay_mult': decay_mult
        })

    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    # Two-headed objective: contrastive (siamese) loss + cross-entropy
    # classification loss, combined downstream with WEI_S / WEI_C weights.
    criterions = []
    siamese_loss = ContrastiveLoss(margin=2.0).to(devices[0])
    classifiy_loss = nn.CrossEntropyLoss().to(devices[0])
    criterions.append(siamese_loss)
    criterions.append(classifiy_loss)

    scheduler = WarmStartCosineAnnealingLR(optimizer,
                                           T_max=args.epochs,
                                           T_warm=10)
    for epoch in range(start_epochs, args.epochs):
        WRITER.add_scalar('Lr/epoch', get_lr(optimizer), epoch)
        loss_train_s, loss_train_c = train(train_loader, model, criterions,
                                           optimizer, epoch)
        loss_train = WEI_S * loss_train_s + WEI_C * loss_train_c
        scheduler.step(epoch)
        if epoch % EVAL_FREQ == 0 or epoch == args.epochs - 1:
            loss_val_s, loss_val_c, acc, report = validate(
                val_loader, model, criterions, epoch)
            loss_val = WEI_S * loss_val_s + WEI_C * loss_val_c
            # "Best" is tracked by classification loss alone, not the
            # combined loss.
            is_best = (loss_val_c < loss_min)
            loss_min = min(loss_val_c, loss_min)
            # TensorBoard visualization of per-epoch metrics.
            WRITER.add_text(tag='Classification Report',
                            text_string=report,
                            global_step=epoch)
            WRITER.add_scalar('Accuracy/epoch', acc, epoch)
            WRITER.add_scalars('Siamese Loss/epoch', {
                'Train': loss_train_s,
                'Val': loss_val_s
            }, epoch)
            WRITER.add_scalars('Classification Loss/epoch', {
                'Train': loss_train_c,
                'Val': loss_val_c
            }, epoch)
            WRITER.add_scalars('Combine Loss/epoch', {
                'Train': loss_train,
                'Val': loss_val
            }, epoch)
            if is_best or epoch % SAVE_FREQ == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'loss_min': loss_min,
                    },
                    is_best,
                    filename='checkpoint.pth.tar')
    WRITER.close()