Beispiel #1
0
def test(model, ema, args, data):
    """Evaluate ``model`` on the dev set using its EMA (shadow) weights.

    The EMA weights are temporarily copied into every trainable parameter of
    ``model``. Start/end predictions are decoded for each dev batch and
    written as JSON to ``args.prediction_file``, then scored with
    ``evaluate.main(args)``. The live (non-EMA) weights are restored
    afterwards, even if evaluation raises.

    Args:
        model: span model; ``model(batch)`` returns start/end logits
            ``(p1, p2)`` that are fed to ``CrossEntropyLoss``.
        ema: EMA helper whose ``get(name)`` returns the shadow tensor for a
            parameter name.
        args: namespace providing ``gpu`` and ``prediction_file`` (plus
            whatever ``evaluate.main`` reads).
        data: dataset wrapper exposing ``dev_iter`` and ``WORD.vocab.itos``.

    Returns:
        Tuple ``(loss, exact_match, f1)``. ``loss`` is the start + end
        cross-entropy summed over all dev batches.
    """
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    log_softmax = nn.LogSoftmax(dim=1)
    loss = 0
    answers = dict()
    model.eval()

    # Swap the EMA weights in, keeping a backup of the live weights.
    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))

    try:
        with torch.no_grad():
            for batch in iter(data.dev_iter):
                p1, p2 = model(batch)
                batch_loss = criterion(p1, batch.s_idx) + criterion(
                    p2, batch.e_idx)
                loss += batch_loss.item()

                # score[b, s, e] = log P(start=s) + log P(end=e). The
                # strictly-lower-triangular -inf mask removes spans whose
                # end precedes their start.
                batch_size, c_len = p1.size()
                mask = (torch.ones(c_len, c_len) *
                        float('-inf')).to(device).tril(-1).unsqueeze(0).expand(
                            batch_size, -1, -1)
                score = (log_softmax(p1).unsqueeze(2) +
                         log_softmax(p2).unsqueeze(1)) + mask
                score, s_idx = score.max(dim=1)
                score, e_idx = score.max(dim=1)
                # squeeze(1) rather than squeeze(): keep the batch axis when
                # batch_size == 1 so that s_idx[i] stays indexable.
                s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)

                for i in range(batch_size):
                    qid = batch.id[i]
                    answer = batch.c_word[0][i][s_idx[i]:e_idx[i] + 1]
                    answers[qid] = ' '.join(
                        [data.WORD.vocab.itos[idx] for idx in answer])
    finally:
        # Always put the live (non-EMA) weights back.
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(backup_params.get(name))

    with open(args.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)

    results = evaluate.main(args)
    return loss, results['exact_match'], results['f1']
Beispiel #2
0
def get_vis(args, data):
    """Dump per-example model internals for visualisation.

    Loads a pretrained BiDAF checkpoint from ``args.load_path``. It then
    runs the model over the training split and the dev split, in that order.
    For every batch it records:

    - context and question tokens;
    - ground-truth and predicted span indices;
    - the loss;
    - the full span-score tensor;
    - a copy of ``model.save_data``.

    Records are written with ``np.save`` in chunks of 2000, to files named
    ``'{chunk}_{trainData|testData}'`` in the working directory. Iteration
    over a split stops once its iterator reports ``epoch == 1``.
    """
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)
    print("load Pretrained model")
    # map_location lets a checkpoint saved on GPU load on a CPU-only host,
    # consistent with the CPU fallback chosen for `device` above.
    state_dict = torch.load(args.load_path, map_location=device)
    model.load_state_dict(state_dict)
    criterion = nn.CrossEntropyLoss()
    log_softmax = nn.LogSoftmax(dim=1)
    # NOTE(review): presumably makes BiDAF record internals into
    # `model.save_data` (read below) — confirm in the model code.
    model.save_mode = True
    model.eval()

    def save_vis_data(train=True):
        """Collect visualisation records for one split and np.save them."""
        save_data = []
        iterator = data.train_iter if train else data.dev_iter
        mode = 'trainData' if train else 'testData'
        print('Mode :{}'.format(mode))
        count = 0
        with torch.no_grad():
            for batch in iterator:
                # The iterator keeps cycling; stop once the first pass ends.
                if int(iterator.epoch) == 1:
                    break
                p1, p2 = model(batch)
                # NOTE(review): unsqueeze(0) suggests the model returns
                # un-batched 1-D logits in save mode — confirm.
                p1 = p1.unsqueeze(0)
                p2 = p2.unsqueeze(0)
                itos = data.WORD.vocab.itos
                c_words = [itos[tok.item()] for tok in batch.c_word[0][0]]
                q_words = [itos[tok.item()] for tok in batch.q_word[0][0]]
                batch_loss = criterion(p1, batch.s_idx) + criterion(
                    p2, batch.e_idx)
                batch_size, c_len = p1.size()

                # Span scores: log P(start=s) + log P(end=e), with spans
                # whose end precedes their start masked to -inf.
                mask = (torch.ones(c_len, c_len) * float('-inf')).to(
                    device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
                scores = (log_softmax(p1).unsqueeze(2) +
                          log_softmax(p2).unsqueeze(1)) + mask
                best_per_end, s_idx = scores.max(dim=1)
                _, e_idx = best_per_end.max(dim=1)
                s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()

                save_data.append({
                    'context': c_words,
                    'question': q_words,
                    'gt_s_idx': batch.s_idx.cpu().numpy(),
                    'gt_e_idx': batch.e_idx.cpu().numpy(),
                    'save_data': model.save_data.copy(),
                    'loss': batch_loss.cpu().numpy(),
                    'prediction_s_idx': s_idx.cpu().numpy(),
                    'prediction_e_idx': e_idx.cpu().numpy(),
                    'prediction_scores': scores.cpu().numpy(),
                })
                # Flush every 2000 records to bound memory use.
                if len(save_data) % 2000 == 0:
                    np.save('{}_{}'.format(count, mode), save_data)
                    save_data = []
                    count += 1
            # Save the final (possibly partial, possibly empty) chunk.
            np.save('{}_{}'.format(count, mode), save_data)

    save_vis_data()
    save_vis_data(False)
def main():
    """Run FixMatch semi-supervised training configured from the CLI.

    Steps:

    1. Parse CLI options.
    2. Seed the torch/random/numpy RNGs when ``--seed`` is non-negative.
    3. Build the model and criteria (``set_model``), the labeled/unlabeled
       train loaders, the val loader, the label guesser and the EMA wrapper.
    4. Create a Nesterov SGD optimizer; parameters whose name contains
       ``'bn'`` get no weight decay.
    5. Create a cosine LR schedule over all iterations.
    6. Train for ``--n-epoches`` epochs, evaluating the EMA model after each
       epoch.

    Metrics go to the logger and writer returned by ``setup_default_logging``.
    """
    parser = argparse.ArgumentParser(description=' FixMatch Training')
    parser.add_argument('--wresnet-k',
                        default=2,
                        type=int,
                        help='width factor of wide resnet')
    parser.add_argument('--wresnet-n',
                        default=28,
                        type=int,
                        help='depth of wide resnet')
    parser.add_argument('--dataset',
                        type=str,
                        default='CIFAR10',
                        help='name of the dataset (e.g. CIFAR10)')
    parser.add_argument('--n-labeled',
                        type=int,
                        default=40,
                        help='number of labeled samples for training')
    parser.add_argument('--n-epoches',
                        type=int,
                        default=1024,
                        help='number of training epoches')
    parser.add_argument('--batchsize',
                        type=int,
                        default=40,
                        help='train batch size of labeled samples')
    parser.add_argument('--mu',
                        type=int,
                        default=7,
                        help='factor of train batch size of unlabeled samples')
    parser.add_argument('--thr',
                        type=float,
                        default=0.95,
                        help='pseudo label threshold')
    parser.add_argument('--n-imgs-per-epoch',
                        type=int,
                        default=64 * 1024,
                        help='number of training images for each epoch')
    parser.add_argument('--lam-u',
                        type=float,
                        default=1.,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--ema-alpha',
                        type=float,
                        default=0.999,
                        help='decay rate for ema module')
    parser.add_argument('--lr',
                        type=float,
                        default=0.03,
                        help='learning rate for training')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='momentum for optimizer')
    parser.add_argument('--seed',
                        type=int,
                        default=-1,
                        help='seed for random behaviors, no seed if negtive')
    parser.add_argument('--temperature',
                        type=float,
                        default=0.5,
                        help='temperature for loss function')
    args = parser.parse_args()

    logger, writer = setup_default_logging(args)
    logger.info(dict(args._get_kwargs()))

    # Global seeding. --seed promises "no seed if negative", so 0 is a valid
    # seed and must not be skipped.
    if args.seed >= 0:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epoches

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.n_labeled}")
    logger.info(f"  Num Epochs = {args.n_epoches}")
    logger.info(f"  Iterations per epoch = {n_iters_per_epoch}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    logger.info(f"  Total optimization steps = {n_iters_all}")

    model, criteria_x, criteria_u, criteria_z = set_model(args)
    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    dltrain_x, dltrain_u = get_train_loader(args.dataset,
                                            args.batchsize,
                                            args.mu,
                                            n_iters_per_epoch,
                                            L=args.n_labeled)
    dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2)

    lb_guessor = LabelGuessor(thresh=args.thr)

    ema = EMA(model, args.ema_alpha)

    # Parameters whose name contains 'bn' (presumably BatchNorm weights and
    # biases) are excluded from weight decay; everything else is decayed.
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        if 'bn' in name:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [{
        'params': wd_params
    }, {
        'params': non_wd_params,
        'weight_decay': 0
    }]
    optim = torch.optim.SGD(param_list,
                            lr=args.lr,
                            weight_decay=args.weight_decay,
                            momentum=args.momentum,
                            nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim,
                                        max_iter=n_iters_all,
                                        warmup_iter=0)

    train_args = dict(model=model,
                      criteria_x=criteria_x,
                      criteria_u=criteria_u,
                      criteria_z=criteria_z,
                      optim=optim,
                      lr_schdlr=lr_schdlr,
                      ema=ema,
                      dltrain_x=dltrain_x,
                      dltrain_u=dltrain_u,
                      lb_guessor=lb_guessor,
                      lambda_u=args.lam_u,
                      n_iters=n_iters_per_epoch,
                      logger=logger)
    best_acc = -1
    best_epoch = 0
    logger.info('-----------start training--------------')
    for epoch in range(args.n_epoches):
        # NOTE(review): confirm this unpacking order matches the tuple
        # returned by train_one_epoch.
        train_loss, loss_x, loss_u, loss_u_real, loss_simclr, mask_mean = train_one_epoch(
            epoch, **train_args)

        top1, top5, valid_loss = evaluate(ema, dlval, criteria_x)

        writer.add_scalars('train/1.loss', {
            'train': train_loss,
            'test': valid_loss
        }, epoch)
        writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
        writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
        writer.add_scalar('train/4.train_loss_u_real', loss_u_real, epoch)
        writer.add_scalar('train/4.train_loss_simclr', loss_simclr, epoch)
        writer.add_scalar('train/5.mask_mean', mask_mean, epoch)
        writer.add_scalars('test/1.test_acc', {
            'top1': top1,
            'top5': top5
        }, epoch)

        if best_acc < top1:
            best_acc = top1
            best_epoch = epoch

        logger.info(
            "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch{}"
            .format(epoch, top1, top5, best_acc, best_epoch))

    writer.close()
Beispiel #4
0
def train(args, data):
    """Train BiDAF with Adadelta while tracking an EMA of trainable weights.

    Every ``args.print_freq`` iterations, the model is evaluated on the dev
    set with its EMA weights via ``test``. Metrics are written to
    TensorBoard under ``runs/<args.model_time>``, and the best-F1 EMA
    snapshot is kept. Training stops when the train iterator reaches epoch
    ``args.epoch``.

    Returns:
        A deep copy of the model holding the EMA weights that produced the
        best dev F1. If no dev evaluation ran, the copy holds the final EMA
        weights instead.
    """
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(log_dir='runs/' + args.model_time)

    def ema_snapshot():
        """Return a deep copy of ``model`` with the EMA weights loaded."""
        snapshot = copy.deepcopy(model)
        for name, param in snapshot.named_parameters():
            if param.requires_grad:
                param.data.copy_(ema.get(name))
        return snapshot

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    best_model = None

    iterator = data.train_iter
    for i, batch in enumerate(iterator):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        p1, p2 = model(batch)

        optimizer.zero_grad()
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        if (i + 1) % args.print_freq == 0:
            # `loss` here is the sum over the last print_freq batches.
            dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('exact_match/dev', dev_exact, c)
            writer.add_scalar('f1/dev', dev_f1, c)
            print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                  f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')

            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_exact = dev_exact
                # test() scores the EMA weights (the raw weights are
                # restored afterwards), so snapshot the weights that were
                # actually evaluated.
                best_model = ema_snapshot()

            loss = 0
            model.train()

    writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')

    if best_model is None:
        # No dev evaluation ran (fewer than print_freq iterations).
        best_model = ema_snapshot()
    return best_model
def main():
    """Run FixMatch + SimCLR semi-supervised training configured from the CLI.

    Setup:

    1. Parse CLI options.
    2. Seed the RNGs when ``--seed`` is non-negative.
    3. Build the model and criteria, the three train loaders from
       ``get_train_loader_mix``, the val loader, the label guesser and the
       EMA wrapper.
    4. Create a Nesterov SGD optimizer; parameters whose name contains
       ``'bn'`` get no weight decay.
    5. Create a cosine LR schedule.

    Each epoch then:

    1. Logs linear-probe accuracy of the EMA model
       (``evaluate_linear_Clf``).
    2. Runs one semi-supervised epoch (``train_one_epoch``).
    3. Evaluates the EMA model and logs losses and accuracies.
    """
    parser = argparse.ArgumentParser(description=' FixMatch Training')
    parser.add_argument('--wresnet-k',
                        default=2,
                        type=int,
                        help='width factor of wide resnet')
    parser.add_argument('--wresnet-n',
                        default=28,
                        type=int,
                        help='depth of wide resnet')
    parser.add_argument('--dataset',
                        type=str,
                        default='CIFAR10',
                        help='name of the dataset (e.g. CIFAR10)')
    parser.add_argument('--n-labeled',
                        type=int,
                        default=40,
                        help='number of labeled samples for training')
    parser.add_argument('--n-epoches',
                        type=int,
                        default=1024,
                        help='number of training epoches')
    parser.add_argument('--batchsize',
                        type=int,
                        default=40,
                        help='train batch size of labeled samples')
    parser.add_argument('--mu',
                        type=int,
                        default=7,
                        help='factor of train batch size of unlabeled samples')
    parser.add_argument('--thr',
                        type=float,
                        default=0.95,
                        help='pseudo label threshold')
    parser.add_argument('--n-imgs-per-epoch',
                        type=int,
                        default=64 * 1024,
                        help='number of training images for each epoch')
    parser.add_argument('--lam-u',
                        type=float,
                        default=1.,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--lam-s',
                        type=float,
                        default=0.2,
                        help='coefficient of unlabeled loss SimCLR')
    parser.add_argument('--ema-alpha',
                        type=float,
                        default=0.999,
                        help='decay rate for ema module')
    parser.add_argument('--lr',
                        type=float,
                        default=0.03,
                        help='learning rate for training')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='momentum for optimizer')
    parser.add_argument('--seed',
                        type=int,
                        default=-1,
                        help='seed for random behaviors, no seed if negtive')
    parser.add_argument('--temperature',
                        type=float,
                        default=0.5,
                        help='temperature for loss function')
    args = parser.parse_args()

    logger, writer = setup_default_logging(args)
    logger.info(dict(args._get_kwargs()))

    # Global seeding. --seed promises "no seed if negative", so 0 is a valid
    # seed and must not be skipped.
    if args.seed >= 0:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epoches

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.n_labeled}")
    logger.info(f"  Num Epochs = {args.n_epoches}")
    logger.info(f"  Iterations per epoch = {n_iters_per_epoch}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    logger.info(f"  Total optimization steps = {n_iters_all}")

    model, criteria_x, criteria_u, criteria_z = set_model(args)

    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    dltrain_x, dltrain_u, dltrain_f = get_train_loader_mix(args.dataset,
                                                           args.batchsize,
                                                           args.mu,
                                                           n_iters_per_epoch,
                                                           L=args.n_labeled)
    dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2)

    lb_guessor = LabelGuessor(thresh=args.thr)

    ema = EMA(model, args.ema_alpha)

    # Parameters whose name contains 'bn' (presumably BatchNorm weights and
    # biases) are excluded from weight decay; everything else is decayed.
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        if 'bn' in name:
            non_wd_params.append(param)
        else:
            wd_params.append(param)

    param_list = [{
        'params': wd_params
    }, {
        'params': non_wd_params,
        'weight_decay': 0
    }]

    optim_fix = torch.optim.SGD(param_list,
                                lr=args.lr,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum,
                                nesterov=True)
    lr_schdlr_fix = WarmupCosineLrScheduler(optim_fix,
                                            max_iter=n_iters_all,
                                            warmup_iter=0)

    train_args = dict(model=model,
                      criteria_x=criteria_x,
                      criteria_u=criteria_u,
                      criteria_z=criteria_z,
                      optim=optim_fix,
                      lr_schdlr=lr_schdlr_fix,
                      ema=ema,
                      dltrain_x=dltrain_x,
                      dltrain_u=dltrain_u,
                      dltrain_f=dltrain_f,
                      lb_guessor=lb_guessor,
                      lambda_u=args.lam_u,
                      lambda_s=args.lam_s,
                      n_iters=n_iters_per_epoch,
                      logger=logger,
                      bt=args.batchsize,
                      mu=args.mu)

    best_acc = -1
    best_epoch = 0
    logger.info('-----------start training--------------')

    for epoch in range(args.n_epoches):
        # Record the accuracy of a linear classifier on the backbone output
        # ("h space") of the EMA model before this epoch's training.
        top1, top5, valid_loss = evaluate_linear_Clf(ema, dltrain_x, dlval,
                                                     criteria_x)
        writer.add_scalars('test/1.test_linear_acc', {
            'top1': top1,
            'top5': top5
        }, epoch)

        logger.info("Epoch {}. on h space Top1: {:.4f}. Top5: {:.4f}.".format(
            epoch, top1, top5))

        # Semi-supervised training.
        # NOTE(review): confirm this unpacking order matches the tuple
        # returned by train_one_epoch.
        train_loss, loss_x, loss_u, loss_u_real, mask_mean, loss_simclr = train_one_epoch(
            epoch, **train_args)
        top1, top5, valid_loss = evaluate(ema, dlval, criteria_x)

        writer.add_scalar('train/4.train_loss_simclr', loss_simclr, epoch)
        writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
        writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
        writer.add_scalar('train/4.train_loss_u_real', loss_u_real, epoch)
        writer.add_scalar('train/5.mask_mean', mask_mean, epoch)

        writer.add_scalars('train/1.loss', {
            'train': train_loss,
            'test': valid_loss
        }, epoch)
        writer.add_scalars('test/1.test_acc', {
            'top1': top1,
            'top5': top5
        }, epoch)

        if best_acc < top1:
            best_acc = top1
            best_epoch = epoch

        logger.info(
            "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch{}"
            .format(epoch, top1, top5, best_acc, best_epoch))

    writer.close()
# Build the SE_GCN model. The input feature sizes are read off the first
# element of each preprocessed feature list.
model = SE_GCN(args,
               W2V,
               MAX_LEN,
               embed_size,
               nfeat=v_texts_w2v_idxs_l_list[0].shape[1],
               nfeat_v=v_features_list[0].shape[1],
               nfeat_g=len(g_features[0]),
               nhid_vfeat=args.hidden_vfeat,
               nhid_siamese=args.hidden_siamese,
               dropout_vfeat=args.dropout_vfeat,
               dropout_siamese=args.dropout_siamese,
               nhid_final=args.hidden_final)
summarize_model(model)

# Optional exponential moving average of the model weights.
if args.use_ema:
    ema = EMA(args.ema_decay)
    ema.register(model)

# Optimizer and scheduler: Adam over trainable parameters only.
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(params=parameters,
                       lr=args.lr,
                       betas=(args.beta1, args.beta2),
                       eps=1e-8,
                       weight_decay=3e-7)
# Logarithmic LR warm-up. The multiplier log(step + 1) / log(lr_warm_up_num)
# rises from 0 and reaches 1.0 at step lr_warm_up_num - 1. It is then held
# at 1.0 instead of growing without bound. lr_warm_up_num must be > 1,
# otherwise math.log gives 0 (ZeroDivisionError) or fails.
cr = 1.0 / math.log(args.lr_warm_up_num)
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda ee: min(1.0, cr * math.log(ee + 1)))
Beispiel #7
0
    def __init__(self, cfg, logger, writer):
        """Set up the model, EMA, optimizer, loss, metrics and data loaders.

        Args:
            cfg: experiment config. Reads ``data`` (num_classes,
                source/target/loader settings), ``opt`` (kind, momentum,
                weight_decay), ``ema_decay`` and ``lam_aug``.
            logger: logger stored on the instance.
            writer: summary writer stored on the instance.

        Raises:
            NotImplementedError: if ``cfg.opt.kind`` is not "SGD" or "Adam",
                the source dataset is not 'synthia' or 'gta5', or the target
                dataset is not 'cityscapes'.
        """
        self.cfg = cfg
        self.device = torch.device('cuda')
        self.logger = logger
        self.writer = writer

        # Progress counters and best-score bookkeeping.
        self.epoch = self.iter = 0
        self.current_MIoU = 0
        self.best_MIou = 0
        self.best_source_MIou = 0

        self.evaluator = Eval(cfg.data.num_classes)

        # Pixels labelled -1 are ignored by the loss.
        self.ignore_index = -1
        self.loss = nn.CrossEntropyLoss(ignore_index=self.ignore_index)

        # TODO: multi-GPU via nn.DataParallel is untested.
        self.model, params = get_model(cfg)
        self.model.to(self.device)

        self.ema = EMA(self.model, cfg.ema_decay)

        # Optimizer chosen by name; factories keep construction lazy.
        opt_cfg = cfg.opt
        optimizer_factories = {
            "SGD": lambda: torch.optim.SGD(
                params,
                momentum=opt_cfg.momentum,
                weight_decay=opt_cfg.weight_decay),
            "Adam": lambda: torch.optim.Adam(
                params,
                betas=(0.9, 0.99),
                weight_decay=opt_cfg.weight_decay),
        }
        if opt_cfg.kind not in optimizer_factories:
            raise NotImplementedError()
        self.optimizer = optimizer_factories[opt_cfg.kind]()
        self.lr_factor = 10

        loader_kwargs = cfg.data.loader.kwargs

        def make_loader(dataset, is_train):
            # Training loaders shuffle and drop the ragged last batch;
            # validation loaders do neither.
            return DataLoader(dataset,
                              shuffle=is_train,
                              drop_last=is_train,
                              **loader_kwargs)

        # Source domain.
        source_cfg = cfg.data.source
        source_dataset_cls = {
            'synthia': SYNTHIA_Dataset,
            'gta5': GTA5_Dataset,
        }.get(source_cfg.dataset)
        if source_dataset_cls is None:
            raise NotImplementedError()
        source_train_dataset = source_dataset_cls(split='train',
                                                  **source_cfg.kwargs)
        source_val_dataset = source_dataset_cls(split='val',
                                                **source_cfg.kwargs)
        self.source_dataloader = make_loader(source_train_dataset, True)
        self.source_val_dataloader = make_loader(source_val_dataset, False)

        # Target domain.
        target_cfg = cfg.data.target
        if target_cfg.dataset != 'cityscapes':
            raise NotImplementedError()
        target_train_dataset = City_Dataset(split='train', **target_cfg.kwargs)
        target_val_dataset = City_Dataset(split='val', **target_cfg.kwargs)
        self.target_dataloader = make_loader(target_train_dataset, True)
        self.target_val_dataloader = make_loader(target_val_dataset, False)

        # Perturbations are only built when the augmentation loss is active.
        if cfg.lam_aug > 0:
            self.aug = get_augmentation()
Beispiel #8
0
class Trainer():
    def __init__(self, cfg, logger, writer):
        """Set up loss, model, EMA copy, optimizer and source/target dataloaders.

        Args:
            cfg: experiment config. Reads ``cfg.data`` (``num_classes``,
                ``source``, ``target``, ``loader``), ``cfg.opt`` (``kind``,
                ``momentum``, ``weight_decay``), ``cfg.ema_decay`` and
                ``cfg.lam_aug``.
            logger: logger stored for progress messages.
            writer: logging sink stored for use by the other methods.

        Raises:
            NotImplementedError: if ``cfg.opt.kind`` or a dataset name is not
                one of the supported values.
        """

        # Args
        self.cfg = cfg
        self.device = torch.device('cuda')
        self.logger = logger
        self.writer = writer

        # Counters
        self.epoch = 0
        self.iter = 0
        self.current_MIoU = 0
        self.best_MIou = 0
        # Defined up front so reporting the best iteration never raises
        # AttributeError before a first improvement has been recorded.
        self.best_iter = 0
        self.best_source_MIou = 0

        # Metrics
        self.evaluator = Eval(self.cfg.data.num_classes)

        # Loss: targets equal to ignore_index contribute nothing
        self.ignore_index = -1
        self.loss = nn.CrossEntropyLoss(ignore_index=self.ignore_index)

        # Model
        self.model, params = get_model(self.cfg)
        # self.model = nn.DataParallel(self.model, device_ids=[0])  # TODO: test multi-gpu
        self.model.to(self.device)

        # EMA
        self.ema = EMA(self.model, self.cfg.ema_decay)

        # Optimizer
        # NOTE(review): no lr is passed here; presumably `params` carries a
        # per-group lr (older torch SGD requires one) - confirm in get_model.
        if self.cfg.opt.kind == "SGD":
            self.optimizer = torch.optim.SGD(
                params,
                momentum=self.cfg.opt.momentum,
                weight_decay=self.cfg.opt.weight_decay)
        elif self.cfg.opt.kind == "Adam":
            self.optimizer = torch.optim.Adam(
                params,
                betas=(0.9, 0.99),
                weight_decay=self.cfg.opt.weight_decay)
        else:
            raise NotImplementedError()
        self.lr_factor = 10

        loader_kwargs = self.cfg.data.loader.kwargs

        def make_loaders(train_dataset, val_dataset):
            # Train: shuffled, incomplete final batch dropped.
            # Val: fixed order, every sample kept.
            train_loader = DataLoader(train_dataset,
                                      shuffle=True,
                                      drop_last=True,
                                      **loader_kwargs)
            val_loader = DataLoader(val_dataset,
                                    shuffle=False,
                                    drop_last=False,
                                    **loader_kwargs)
            return train_loader, val_loader

        # Source
        source_cfg = self.cfg.data.source
        if source_cfg.dataset == 'synthia':
            source_cls = SYNTHIA_Dataset
        elif source_cfg.dataset == 'gta5':
            source_cls = GTA5_Dataset
        else:
            raise NotImplementedError()
        self.source_dataloader, self.source_val_dataloader = make_loaders(
            source_cls(split='train', **source_cfg.kwargs),
            source_cls(split='val', **source_cfg.kwargs))

        # Target
        target_cfg = self.cfg.data.target
        if target_cfg.dataset == 'cityscapes':
            target_cls = City_Dataset
        else:
            raise NotImplementedError()
        self.target_dataloader, self.target_val_dataloader = make_loaders(
            target_cls(split='train', **target_cfg.kwargs),
            target_cls(split='val', **target_cfg.kwargs))

        # Perturbations
        if self.cfg.lam_aug > 0:
            self.aug = get_augmentation()

    def train(self):

        # Loop over epochs
        self.continue_training = True
        while self.continue_training:

            # Train for a single epoch
            self.train_one_epoch()

            # Use EMA params to evaluate performance
            self.ema.apply_shadow()
            self.ema.model.eval()
            self.ema.model.cuda()

            # Validate on source (if possible) and target
            if self.cfg.data.source_val_iterations > 0:
                self.validate(mode='source')
            PA, MPA, MIoU, FWIoU = self.validate()

            # Restore current (non-EMA) params for training
            self.ema.restore()

            # Log val results
            self.writer.add_scalar('PA', PA, self.epoch)
            self.writer.add_scalar('MPA', MPA, self.epoch)
            self.writer.add_scalar('MIoU', MIoU, self.epoch)
            self.writer.add_scalar('FWIoU', FWIoU, self.epoch)

            # Save checkpoint if new best model
            self.current_MIoU = MIoU
            is_best = MIoU > self.best_MIou
            if is_best:
                self.best_MIou = MIoU
                self.best_iter = self.iter
                self.logger.info("=> Saving a new best checkpoint...")
                self.logger.info(
                    "=> The best val MIoU is now {:.3f} from iter {}".format(
                        self.best_MIou, self.best_iter))
                self.save_checkpoint('best.pth')
            else:
                self.logger.info("=> The MIoU of val did not improve.")
                self.logger.info(
                    "=> The best val MIoU is still {:.3f} from iter {}".format(
                        self.best_MIou, self.best_iter))
            self.epoch += 1

        # Save final checkpoint
        self.logger.info("=> The best MIou was {:.3f} at iter {}".format(
            self.best_MIou, self.best_iter))
        self.logger.info(
            "=> Saving the final checkpoint to {}".format('final.pth'))
        self.save_checkpoint('final.pth')

    def train_one_epoch(self):
        """Train for one pass over the zipped source/target training loaders.

        For every (source, target) batch pair, each enabled loss term is
        backpropagated as soon as it is computed. One optimizer step, followed
        by an EMA parameter update, is then taken per batch.

        * ``source``: supervised cross-entropy on the source batch, optionally
          Fourier-mixed with the target batch (``cfg.source_fourier``).
        * pseudolabels: a no-grad forward pass on the target batch. Main-head
          pixels whose max softmax probability is not above
          ``cfg.pseudolabel_threshold`` become ``ignore_index``.
        * ``aug`` / ``fourier`` / ``cutmix``: losses between predictions on a
          perturbed target batch and those pseudolabels. They are weighted by
          ``cfg.lam_aug`` / ``cfg.lam_fourier`` / ``cfg.lam_cutmix`` and run
          only when that weight is positive.

        With ``cfg.aux`` each term also adds the auxiliary head's loss, scaled
        by ``cfg.lam_aux``. Once ``self.iter`` exceeds ``cfg.opt.iterations``
        the pass stops early and ``self.continue_training`` is set to False.
        The EMA buffers are refreshed at the end of the pass in both cases.
        """

        # Load and reset
        self.model.train()
        self.evaluator.reset()

        def unpack(x):
            # The model returns either a bare prediction or a (main, aux) tuple.
            return (x[0], x[1]) if isinstance(x, tuple) else (x, None)

        def forward_backward(images, target_main, target_aux, weight, key,
                             losses):
            # Forward `images` and backprop weight * main loss (+ the scaled
            # aux loss when cfg.aux is set). Scalar values are recorded under
            # f'{key}_main' / f'{key}_aux'. The aux loss only exists when
            # cfg.aux is set. Graph-holding tensors are released when this
            # helper returns, so no explicit `del` is needed.
            pred_main, pred_aux = unpack(self.model(images))
            loss_main = self.loss(pred_main, target_main) * weight
            combined = loss_main
            if self.cfg.aux:
                loss_aux = self.loss(pred_aux, target_aux) * weight * \
                    self.cfg.lam_aux
                combined = loss_main + loss_aux
            combined.backward()
            losses[f'{key}_main'] = loss_main.item()
            if self.cfg.aux:
                losses[f'{key}_aux'] = loss_aux.item()

        # Training loop
        n_batches = min(len(self.source_dataloader),
                        len(self.target_dataloader))
        for batch_idx, (batch_s, batch_t) in enumerate(
                tqdm(zip(self.source_dataloader, self.target_dataloader),
                     total=n_batches,
                     desc=f"Epoch {self.epoch + 1}")):

            # Learning rate
            self.poly_lr_scheduler(optimizer=self.optimizer)
            self.writer.add_scalar('train/lr',
                                   self.optimizer.param_groups[0]["lr"],
                                   self.iter)

            # Scalar loss values for this batch, e.g. 'source_main'
            losses = {}

            ##########################
            # Source supervised loss #
            ##########################
            x_s, y_s, _ = batch_s
            x_s = x_s.to(self.device)
            y_s = y_s.squeeze(dim=1).to(device=self.device,
                                        dtype=torch.long,
                                        non_blocking=True)
            if self.cfg.source_fourier:
                # Fourier mix: source --> target
                x_s = fourier_mix(src_images=x_s,
                                  tgt_images=batch_t[0].to(self.device),
                                  L=self.cfg.fourier_beta)
            forward_backward(x_s, y_s, y_s, 1.0, 'source', losses)
            del x_s, y_s

            ######################
            # Target Pseudolabel #
            ######################
            x_t, _, _ = batch_t
            x_t = x_t.to(self.device)
            label_aux = None
            with torch.no_grad():
                pred_main, pred_aux = unpack(self.model(x_t))

                # Hard labels from the main head, low-confidence pixels ignored
                prob_main = F.softmax(pred_main, dim=1)
                label_main = torch.argmax(prob_main, dim=1)
                maxprob_main, _ = torch.max(prob_main, dim=1)
                T = self.cfg.pseudolabel_threshold
                ignore_tensor = torch.full((1,),
                                           self.ignore_index,
                                           dtype=torch.long,
                                           device=self.device)
                label_main = torch.where(maxprob_main > T, label_main,
                                         ignore_tensor)
                if self.cfg.aux:
                    # Aux-head labels: argmax of the averaged probabilities,
                    # kept where either head is confident.
                    prob_aux = F.softmax(pred_aux, dim=1)
                    maxprob_aux, _ = torch.max(prob_aux, dim=1)
                    _, label_avg = torch.max((prob_main + prob_aux) / 2, dim=1)
                    confident = (maxprob_main > T) | (maxprob_aux > T)
                    label_aux = torch.where(confident, label_avg,
                                            ignore_tensor)

            ############
            # Aug loss #
            ############
            if self.cfg.lam_aug > 0:
                x_aug, y_aug_main = augment(images=x_t.cpu(),
                                            labels=label_main.cpu(),
                                            aug=self.aug)
                y_aug_main = y_aug_main.to(device=self.device,
                                           non_blocking=True)
                y_aug_aux = None
                if self.cfg.aux:
                    # NOTE(review): this is a second, independent augment()
                    # call. If self.aug is stochastic, y_aug_aux may not be
                    # aligned with x_aug. Confirm augment() semantics.
                    _, y_aug_aux = augment(images=x_t.cpu(),
                                           labels=label_aux.cpu(),
                                           aug=self.aug)
                    y_aug_aux = y_aug_aux.to(device=self.device,
                                             non_blocking=True)
                forward_backward(x_aug.to(self.device), y_aug_main, y_aug_aux,
                                 self.cfg.lam_aug, 'aug', losses)

            ################
            # Fourier Loss #
            ################
            if self.cfg.lam_fourier > 0:
                x_fourier = fourier_mix(src_images=x_t.to(self.device),
                                        tgt_images=batch_s[0].to(self.device),
                                        L=self.cfg.fourier_beta)
                forward_backward(x_fourier.to(self.device), label_main,
                                 label_aux, self.cfg.lam_fourier, 'fourier',
                                 losses)

            ###############
            # CutMix Loss #
            ###############
            if self.cfg.lam_cutmix > 0:
                x_cutmix, y_cutmix = cutmix_combine(
                    images_1=x_t,
                    labels_1=label_main.unsqueeze(dim=1),
                    images_2=batch_s[0].to(self.device),
                    labels_2=batch_s[1].unsqueeze(dim=1).to(self.device,
                                                            dtype=torch.long))
                y_cutmix = y_cutmix.squeeze(dim=1)
                forward_backward(x_cutmix, y_cutmix, y_cutmix,
                                 self.cfg.lam_cutmix, 'cutmix', losses)

            ##################
            # Optimizer step #
            ##################
            # Every loss above was already backpropagated, so .grad holds the
            # sum of their gradients. Take exactly one step per batch.
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Update model EMA parameters each step
            self.ema.update_params()

            # Log per-term losses
            for name, value in losses.items():
                self.writer.add_scalar(f'train/{name}', value, self.iter)
            if batch_idx % 100 == 0:
                log_string = f"[Epoch {self.epoch}]\t"
                log_string += '\t'.join(
                    [f'{n}: {l:.3f}' for n, l in losses.items()])
                self.logger.info(log_string)

            # Increment global iteration counter
            self.iter += 1

            # End training after finishing iterations. Use `break` rather than
            # `return` so the EMA buffer update below still runs.
            if self.iter > self.cfg.opt.iterations:
                self.continue_training = False
                break

        # After each epoch, update model EMA buffers (i.e. batch norm stats)
        self.ema.update_buffer()

    @torch.no_grad()
    def validate(self, mode='target'):
        """Evaluate ``self.model`` on the target (default) or source val set.

        Args:
            mode: ``'target'`` or ``'source'``. Source validation stops after
                ``cfg.data.source_val_iterations`` batches.

        Returns:
            ``(PA, MPA, MIoU, FWIoU)``. When ``cfg.data.source.kwargs.class_16``
            is set, MPA/MIoU/FWIoU are the 13-class variants.

        Scalars and sample images from the last evaluated batch are written to
        ``self.writer`` at step ``self.epoch``. In source mode the tags are
        prefixed with ``'source/'`` so they do not collide with target results.
        If the model returns a tuple, only its first element is evaluated.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        self.logger.info('Validating')
        self.evaluator.reset()
        self.model.eval()

        # Select dataloader
        if mode == 'target':
            val_loader = self.target_val_dataloader
        elif mode == 'source':
            val_loader = self.source_val_dataloader
        else:
            raise NotImplementedError()
        prefix = 'source/' if mode == 'source' else ''

        # Last fully evaluated batch, kept for the image summaries below
        vis_x = vis_label = vis_pred = None
        for val_idx, (x_cpu, y, _) in enumerate(
                tqdm(val_loader, desc=f"Val Epoch {self.epoch + 1}")):
            if mode == 'source' and val_idx >= self.cfg.data.source_val_iterations:
                break

            # Forward
            x = x_cpu.to(self.device)
            y = y.to(device=self.device, dtype=torch.long)
            pred = self.model(x)
            if isinstance(pred, tuple):
                pred = pred[0]

            # Convert to numpy
            label = y.squeeze(dim=1).cpu().numpy()
            argpred = np.argmax(pred.data.cpu().numpy(), axis=1)

            # Add to evaluator
            self.evaluator.add_batch(label, argpred)
            vis_x, vis_label, vis_pred = x_cpu, label, argpred

        # Tensorboard images (skipped when no batch was evaluated)
        if vis_x is not None:
            vis_imgs = 2
            images_inv = inv_preprocess(vis_x.clone().cpu(),
                                        vis_imgs,
                                        numpy_transform=True)
            labels_colors = decode_labels(vis_label, vis_imgs)
            preds_colors = decode_labels(vis_pred, vis_imgs)
            for index, (img, lab, predc) in enumerate(
                    zip(images_inv, labels_colors, preds_colors)):
                self.writer.add_image(f'{prefix}{index}/images', img,
                                      self.epoch)
                self.writer.add_image(f'{prefix}{index}/labels', lab,
                                      self.epoch)
                self.writer.add_image(f'{prefix}{index}/preds', predc,
                                      self.epoch)

        # Calculate and log
        if self.cfg.data.source.kwargs.class_16:
            PA = self.evaluator.Pixel_Accuracy()
            MPA_16, MPA_13 = self.evaluator.Mean_Pixel_Accuracy()
            MIoU_16, MIoU_13 = self.evaluator.Mean_Intersection_over_Union()
            FWIoU_16, FWIoU_13 = self.evaluator.Frequency_Weighted_Intersection_over_Union(
            )
            PC_16, PC_13 = self.evaluator.Mean_Precision()
            self.logger.info(
                'Epoch:{:.3f}, PA:{:.3f}, MPA_16:{:.3f}, MIoU_16:{:.3f}, FWIoU_16:{:.3f}, PC_16:{:.3f}'
                .format(self.epoch, PA, MPA_16, MIoU_16, FWIoU_16, PC_16))
            self.logger.info(
                'Epoch:{:.3f}, PA:{:.3f}, MPA_13:{:.3f}, MIoU_13:{:.3f}, FWIoU_13:{:.3f}, PC_13:{:.3f}'
                .format(self.epoch, PA, MPA_13, MIoU_13, FWIoU_13, PC_13))
            self.writer.add_scalar(f'{prefix}PA', PA, self.epoch)
            self.writer.add_scalar(f'{prefix}MPA_16', MPA_16, self.epoch)
            self.writer.add_scalar(f'{prefix}MIoU_16', MIoU_16, self.epoch)
            self.writer.add_scalar(f'{prefix}FWIoU_16', FWIoU_16, self.epoch)
            self.writer.add_scalar(f'{prefix}MPA_13', MPA_13, self.epoch)
            self.writer.add_scalar(f'{prefix}MIoU_13', MIoU_13, self.epoch)
            self.writer.add_scalar(f'{prefix}FWIoU_13', FWIoU_13, self.epoch)
            PA, MPA, MIoU, FWIoU = PA, MPA_13, MIoU_13, FWIoU_13
        else:
            PA = self.evaluator.Pixel_Accuracy()
            MPA = self.evaluator.Mean_Pixel_Accuracy()
            MIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            PC = self.evaluator.Mean_Precision()
            self.logger.info(
                'Epoch:{:.3f}, PA1:{:.3f}, MPA1:{:.3f}, MIoU1:{:.3f}, FWIoU1:{:.3f}, PC:{:.3f}'
                .format(self.epoch, PA, MPA, MIoU, FWIoU, PC))
            self.writer.add_scalar(f'{prefix}PA', PA, self.epoch)
            self.writer.add_scalar(f'{prefix}MPA', MPA, self.epoch)
            self.writer.add_scalar(f'{prefix}MIoU', MIoU, self.epoch)
            self.writer.add_scalar(f'{prefix}FWIoU', FWIoU, self.epoch)

        return PA, MPA, MIoU, FWIoU

    def save_checkpoint(self, filename='checkpoint.pth'):
        torch.save(
            {
                'epoch': self.epoch + 1,
                'iter': self.iter,
                'state_dict': self.ema.model.state_dict(),
                'shadow': self.ema.shadow,
                'optimizer': self.optimizer.state_dict(),
                'best_MIou': self.best_MIou
            }, filename)

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename, map_location='cpu')

        # Get model state dict
        if not self.cfg.train and 'shadow' in checkpoint:
            state_dict = checkpoint['shadow']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Remove DP/DDP if it exists
        state_dict = {
            k.replace('module.', ''): v
            for k, v in state_dict.items()
        }

        # Load state dict
        if hasattr(self.model, 'module'):
            self.model.module.load_state_dict(state_dict)
        else:
            self.model.load_state_dict(state_dict)
        self.logger.info(f"Model loaded successfully from {filename}")

        # Load optimizer and epoch
        if self.cfg.train and self.cfg.model.resume_from_checkpoint:
            if 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.logger.info(
                    f"Optimizer loaded successfully from {filename}")
            if 'epoch' in checkpoint and 'iter' in checkpoint:
                self.epoch = checkpoint['epoch']
                self.iter = checkpoint[
                    'iter'] if 'iter' in checkpoint else checkpoint['iteration']
                self.logger.info(
                    f"Resuming training from epoch {self.epoch} iter {self.iter}"
                )
        else:
            self.logger.info(f"Did not resume optimizer")

    def poly_lr_scheduler(self,
                          optimizer,
                          init_lr=None,
                          iter=None,
                          max_iter=None,
                          power=None):
        init_lr = self.cfg.opt.lr if init_lr is None else init_lr
        iter = self.iter if iter is None else iter
        max_iter = self.cfg.opt.iterations if max_iter is None else max_iter
        power = self.cfg.opt.poly_power if power is None else power
        new_lr = init_lr * (1 - float(iter) / max_iter)**power
        optimizer.param_groups[0]["lr"] = new_lr
        if len(optimizer.param_groups) == 2:
            optimizer.param_groups[1]["lr"] = 10 * new_lr