Ejemplo n.º 1
0
def eval_maml(model, criterion,
              valloader, device, epoch,
              log_interval, writer, args):
    """Evaluate a MAML meta-model on few-shot validation episodes.

    Each batch from ``valloader`` is one episode: its first
    ``args.shot * args.test_way`` samples form the support set and the rest
    form the query set. Starting from the current meta-parameters, the model
    is adapted with ``args.inner_train_steps`` manual gradient steps of size
    ``args.inner_lr`` on the support set and then scored on the query set.

    Evaluation never updates the meta-parameters, so nothing is
    back-propagated into them and no higher-order graph is built.

    Args:
        model: meta-model exposing ``named_parameters()`` and
            ``functional_forward(inputs, weights)``.
        criterion: loss called as ``criterion(logits, labels)``.
        valloader: iterable of ``(data, label)`` episodes.
        device: device the episode tensors are moved to.
        epoch: current epoch, only used for logging.
        log_interval: log after every ``log_interval`` episodes.
        writer: summary writer forwarded to ``Recorder``.
        args: namespace providing ``shot``, ``test_way``, ``query_val``,
            ``inner_train_steps`` and ``inner_lr``.

    Returns:
        The recorder's running average of ``'val acc'``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    # Create recorder
    averagers = [avg_loss, avg_acc]
    names = ['val loss', 'val acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set evaluation mode
    model.eval()

    recoder.tik()
    recoder.data_tik()
    for i, batch in enumerate(valloader):
        # measure data loading time
        recoder.data_tok()

        # get the inputs (episode labels are rebuilt below)
        data, _lab = [t.to(device) for t in batch]

        n_support = args.shot * args.test_way
        data_shot = data[:n_support]
        data_query = data[n_support:]

        # Start the fast weights from the current meta-parameters.
        fast_weights = OrderedDict(model.named_parameters())

        # Adapt on the support set for `inner_train_steps` iterations.
        # The adapted weights are never differentiated again here, so the
        # inner gradients need no graph of their own (create_graph=False
        # yields exactly the same weight values as the second-order path).
        for _ in range(args.inner_train_steps):
            y = create_nshot_task_label(args.test_way, args.shot).to(device)
            logits = model.functional_forward(data_shot, fast_weights)
            loss = criterion(logits, y)
            gradients = torch.autograd.grad(loss, fast_weights.values(),
                                            create_graph=False)

            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - args.inner_lr * grad)
                for (name, param), grad in zip(fast_weights.items(), gradients)
            )

        # Score the adapted weights on the query set. No backward pass:
        # evaluation must not accumulate gradients into the meta-model.
        y = create_nshot_task_label(args.test_way, args.query_val).to(device)
        with torch.no_grad():
            logits = model.functional_forward(data_query, fast_weights)
            loss = criterion(logits, y)

            # Get post-update accuracies
            y_pred = logits.softmax(-1)
            acc = accuracy(y_pred, y)[0]

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update average value
        vals = [loss.item(), acc]
        recoder.update(vals)

        if i % log_interval == log_interval - 1:
            recoder.log(epoch, i, len(valloader), mode='Eval')

    return recoder.get_avg('val acc')
Ejemplo n.º 2
0
def evaluate_confusion_matrix(model, criterion,
                              valloader, device, epoch,
                              log_interval, writer, args, relation, name,
                              category_space='novel',
                              base_class=400, novel_class=100):
    """Evaluate against the global prototype set and save a confusion matrix.

    Every sample is embedded with ``model.baseModel`` and compared by
    ``relation`` against ``torch.cat([model.global_base, model.global_novel])``.
    Loss and accuracy are logged through a ``Recorder``. A
    ``(base_class + novel_class)``-square confusion matrix (rows = ground
    truth, columns = prediction) is accumulated, row-normalized, optionally
    cropped, and written to ``name`` as CSV.

    Args:
        model: model providing ``baseModel``, ``global_base`` and
            ``global_novel``.
        criterion: loss called as ``criterion(logits, labels)``.
        valloader: iterable of ``(data, label)`` batches.
        device: device the batch tensors are moved to.
        epoch: current epoch, only used for logging.
        log_interval: log after every ``log_interval`` batches.
        writer: summary writer forwarded to ``Recorder``.
        args: unused here; kept for interface compatibility.
        relation: callable scoring embeddings against the prototype set.
        name: output CSV path.
        category_space: ``'novel'`` keeps only the novel-by-novel block,
            ``'base'`` only the base-by-base block; anything else keeps the
            full matrix.
        base_class: number of base classes (leading indices).
        novel_class: number of novel classes (trailing indices).

    Returns:
        The recorder's running average of ``'val acc'``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    # Create recorder
    averagers = [avg_loss, avg_acc]
    names = ['val loss', 'val acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set evaluation mode
    model.eval()

    recoder.tik()
    recoder.data_tik()
    num_class = base_class + novel_class
    cmat = numpy.zeros([num_class, num_class])
    for i, batch in enumerate(valloader):
        with torch.no_grad():
            # measure data loading time
            recoder.data_tok()

            # get the inputs and labels
            data, lab = [t.to(device) for t in batch]

            # forward
            proto = model.baseModel(data)
            global_set = torch.cat([model.global_base, model.global_novel])
            logits = relation(proto, global_set)

            # compute the loss
            loss = criterion(logits, lab)

            # compute the metrics
            acc = accuracy(logits, lab)[0]

            # Accumulate the confusion matrix (row = ground truth,
            # column = prediction). One host transfer per batch instead of
            # one per sample; numpy.add.at counts repeated pairs correctly.
            predict = logits.argmax(-1)
            numpy.add.at(cmat, (lab.cpu().numpy(), predict.cpu().numpy()), 1)

            # measure elapsed time
            recoder.tok()
            recoder.tik()
            recoder.data_tik()

        # update average value
        vals = [loss.item(), acc]
        recoder.update(vals)

        if i % log_interval == log_interval - 1:
            recoder.log(epoch, i, len(valloader), mode='Test')

    # Normalize each row (ground-truth class) to sum to 1. ``keepdims`` makes
    # the row sums broadcast along each row; dividing by a flat ``sum(1)``
    # would instead scale column j by the sample count of row j. Classes that
    # never occur as ground truth keep an all-zero row rather than NaN.
    row_sums = cmat.sum(axis=1, keepdims=True)
    cmat = numpy.divide(cmat, row_sums, out=numpy.zeros_like(cmat),
                        where=row_sums > 0)
    if category_space == 'novel':
        cmat = cmat[base_class:, base_class:]
    elif category_space == 'base':
        cmat = cmat[:base_class, :base_class]
    df = pd.DataFrame(cmat)
    df.to_csv(name)

    return recoder.get_avg('val acc')
Ejemplo n.º 3
0
def main():
    """Train or evaluate ``mainModel`` on Charades-STA.

    Parses arguments from the module-level ``parser`` (defaults come from
    ``get_and_save_args``), builds the datasets, model and optimizer,
    optionally resumes from ``args.resume``, and then either evaluates once
    (``args.evaluate``) or trains for the stage-dependent number of epochs.
    The best R@1 / R@5 results are tracked in module-level globals, which
    must be initialized before ``main`` runs.

    Raises:
        ValueError: if training is requested but none of ``is_first_stage``,
            ``is_second_stage`` or ``is_third_stage`` is set (there would be
            no parameters to optimize).
    """
    global args, logger, writer, dataset_configs
    global best_top1_epoch, best_top5_epoch, best_top1, best_top5, best_top1_top5, best_top5_top1
    dataset_configs = get_and_save_args(parser)
    parser.set_defaults(**dataset_configs)
    args = parser.parse_args()

    # ================== GPU setting ===============
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Create the directory used for saving models and logs.
    os.makedirs(args.snapshot_pref, exist_ok=True)

    logger = Prepare_logger(args)
    logger.info('\ncreating folder: ' + args.snapshot_pref)

    if not args.evaluate:
        writer = SummaryWriter(args.snapshot_pref)
        recorder = Recorder(args.snapshot_pref)
        recorder.writeopt(args)

    logger.info('\nruntime args\n\n{}\n'.format(json.dumps(vars(args), indent=4)))

    # ================== dataset and model ===============
    # TACoS alternative:
    # word2idx from './data/dataset/TACoS/TACoS_word2id_glove_lower.json'
    # train_dataset = TACoS(args, split='train')
    # test_dataset = TACoS(args, split='test')
    with open('./data/dataset/Charades/Charades_word2id.json', 'r') as f:
        word2idx = json.load(f)
    train_dataset = CharadesSTA(args, split='train')
    test_dataset = CharadesSTA(args, split='test')
    train_dataloader = DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=True, collate_fn=collate_data, num_workers=8, pin_memory=True
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=args.test_batch_size,
        shuffle=False, collate_fn=collate_data, num_workers=8, pin_memory=True
    )
    vocab_size = len(word2idx)

    lr = args.lr
    n_epoch = args.n_epoch

    main_model = mainModel(vocab_size, args, hidden_dim=512, embed_dim=300,
                           bidirection=True, graph_node_features=1024)

    if os.path.exists(args.glove_weights):
        logger.info("Loading glove weights")
        main_model.query_encoder.embedding.weight.data.copy_(torch.load(args.glove_weights))
    else:
        logger.info("Generating glove weights")
        main_model.query_encoder.embedding.weight.data.copy_(glove_init(word2idx))

    main_model = nn.DataParallel(main_model).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Only resume the parameters that exist in the current model.
            model_dict = main_model.state_dict()
            pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items()
                               if k in model_dict}
            model_dict.update(pretrained_dict)
            main_model.load_state_dict(model_dict)
            logger.info("=> loaded checkpoint '{}' (epoch {})"
                        .format(args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        topks, accuracy_topks = evaluate(main_model, test_dataloader, word2idx, False)
        for ind, topk in enumerate(topks):
            print("R@{}: {:.1f}\n".format(topk, accuracy_topks[ind] * 100))
        return

    # Select which parameters each training stage updates.
    if args.is_first_stage:
        for name, value in main_model.named_parameters():
            if 'iou_scores' in name or 'mix_fc' in name:
                value.requires_grad = False
        learned_params = filter(lambda p: p.requires_grad, main_model.parameters())
        n_epoch = 10
    elif args.is_second_stage:
        head_params = main_model.module.fcos.head.iou_scores.parameters()
        fc_params = main_model.module.fcos.head.mix_fc.parameters()
        learned_params = list(head_params) + list(fc_params)
        lr /= 100
    elif args.is_third_stage:
        learned_params = main_model.parameters()
        lr /= 10000
    else:
        raise ValueError('no training stage selected: set one of '
                         'is_first_stage / is_second_stage / is_third_stage')

    optimizer = torch.optim.Adam(learned_params, lr)

    for epoch in range(args.start_epoch, n_epoch):

        train_epoch(main_model, train_dataloader, optimizer, epoch)

        # Compare against the local ``n_epoch``: the first stage overrides
        # it, so ``args.n_epoch`` would miss the real last epoch.
        if (epoch + 1) % args.eval_freq == 0 or epoch == n_epoch - 1:

            val_loss, topks, accuracy_topks = validate_epoch(
                main_model, test_dataloader, epoch, word2idx, False
            )
            top1 = accuracy_topks[0] * 100
            top5 = accuracy_topks[1] * 100

            for ind, topk in enumerate(topks):
                writer.add_scalar('test_result/Recall@top{}'.format(topk), accuracy_topks[ind]*100, epoch)

            state = {
                'epoch': epoch + 1,
                'state_dict': main_model.state_dict(),
                'loss': val_loss,
                'top1': top1,
                'top5': top5,
            }

            is_best_top1 = top1 > best_top1
            best_top1 = max(top1, best_top1)
            if is_best_top1:
                best_top1_epoch = epoch
                best_top1_top5 = top5
            # Pass a fresh dict each time in case save_checkpoint mutates it.
            save_checkpoint(dict(state), is_best_top1, epoch=epoch, top1=top1, top5=top5)

            is_best_top5 = top5 > best_top5
            best_top5 = max(top5, best_top5)
            if is_best_top5:
                best_top5_epoch = epoch
                best_top5_top1 = top1
            save_checkpoint(dict(state), is_best_top5, epoch=epoch, top1=top1, top5=top5)

            writer.add_scalar('test_result/Best_Recall@top1', best_top1, epoch)
            writer.add_scalar('test_result/Best_Recall@top5', best_top5, epoch)

            logger.info(
                "R@1: {:.2f}, R@5: {:.2f}, epoch: {}\n".format(
                    top1, top5, epoch)
            )
            logger.info(
                "Current best top1: R@1: {:.2f}, R@5: {:.2f}, epoch: {} \n".format(
                    best_top1, best_top1_top5, best_top1_epoch)
            )
            logger.info(
                "Current best top5: R@1: {:.2f}, R@5: {:.2f}, epoch: {} \n".format(
                    best_top5_top1, best_top5, best_top5_epoch)
            )
Ejemplo n.º 4
0
def train_one_epoch(model, trainloader, device, epoch, log_interval, writer):
    """Run one training epoch where the model performs its own update step.

    Each batch is ``(Qh, Ph, glove_angles, group_names)``. The three tensors
    are moved to ``device`` and ``model.optimize_parameters(Qh, Ph)`` is
    called; it must return six loss tensors (Ltotal, Lrec, Lltc, Lee,
    Ladv_G, Ladv_D). Their values are averaged per logging window, weighted
    by the batch size, and the meters are reset after every log.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss_meters = [AverageMeter() for _ in range(6)]
    model.train()
    recoder = Recorder(
        loss_meters,
        ['train Ltotal', 'train Lrec', 'train Lltc', 'train Lee',
         'train Ladv G', 'train Ladv D'],
        writer, batch_time, data_time)

    recoder.tik()
    recoder.data_tik()
    for step, batch in enumerate(trainloader):
        recoder.data_tok()

        Qh, Ph, glove_angles, _group_names = batch
        Qh, Ph, glove_angles = [t.to(device) for t in (Qh, Ph, glove_angles)]

        # The model runs forward, backward and the optimizer step itself.
        ltotal, lrec, lltc, lee, ladv_g, ladv_d = model.optimize_parameters(Qh, Ph)

        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        loss_values = [
            term.item()
            for term in (ltotal, lrec, lltc, lee, ladv_g, ladv_d)
        ]
        recoder.update(loss_values, count=Qh.size(0))

        if step == 0 or step % log_interval == log_interval - 1:
            recoder.log(epoch, step, len(trainloader))
            recoder.reset()
Ejemplo n.º 5
0
def test_seq2seq(model, criterion, dataloader, device, epoch, log_interval,
                 writer):
    """Evaluate a seq2seq model without teacher forcing.

    For every batch the model is run with teacher-forcing ratio 0. The first
    decoding step is dropped from both outputs and targets, and three metrics
    are recorded:
    - the criterion loss;
    - token accuracy (``accuracy_score``);
    - mean per-sample WER, computed after removing tokens 0, 1 and 2.

    Args:
        model: seq2seq model called as ``model(imgs, target, 0)``.
        criterion: loss over flattened ``(N, output_dim)`` outputs.
        dataloader: yields dicts with ``'videos'`` and ``'annotations'``.
        device: unused; tensors are moved with ``.cuda()``.
        epoch: current epoch, only used for logging.
        log_interval: log every ``log_interval`` batches (plus the first
            and last batch).
        writer: summary writer forwarded to ``Recorder``.

    Returns:
        The recorder's running average of ``'test acc'``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    avg_wer = AverageMeter()
    # Create recorder
    averagers = [avg_loss, avg_acc, avg_wer]
    names = ['test loss', 'test acc', 'test wer']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set evaluation mode
    model.eval()

    recoder.tik()
    recoder.data_tik()
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            # measure data loading time
            recoder.data_tok()
            # get the data and labels
            imgs = batch['videos'].cuda()
            target = batch['annotations'].permute(1, 0).contiguous().cuda()

            # forward (no teacher forcing)
            outputs = model(imgs, target, 0)

            # target: (batch_size, trg_len)
            # outputs: (trg_len, batch_size, output_dim)
            # skip sos
            output_dim = outputs.shape[-1]
            outputs = outputs[1:].view(-1, output_dim)
            target = target.permute(1, 0)[1:].reshape(-1)

            # compute the loss
            loss = criterion(outputs, target)

            # compute the accuracy
            prediction = torch.max(outputs, 1)[1]
            score = accuracy_score(target.cpu().data.squeeze().numpy(),
                                   prediction.cpu().data.squeeze().numpy())

            # compute wer
            # prediction / target: ((trg_len-1)*batch_size) -> per-sample rows
            batch_size = imgs.shape[0]
            pred_rows = prediction.view(-1, batch_size).permute(1, 0).tolist()
            gold_rows = target.view(-1, batch_size).permute(1, 0).tolist()
            special_tokens = (0, 1, 2)  # padding / sos / eos
            wers = [
                wer([tok for tok in gold if tok not in special_tokens],
                    [tok for tok in pred if tok not in special_tokens])
                for gold, pred in zip(gold_rows, pred_rows)
            ]
            batch_wer = sum(wers) / len(wers)

            # measure elapsed time
            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            # update average value
            vals = [loss.item(), score, batch_wer]
            recoder.update(vals, count=batch_size)

            # logging
            if (batch_idx == 0
                    or batch_idx % log_interval == log_interval - 1
                    or batch_idx == len(dataloader) - 1):
                recoder.log(epoch, batch_idx, len(dataloader), mode='Test')

    # The recorder is keyed by the names registered above ('test acc').
    return recoder.get_avg('test acc')
Ejemplo n.º 6
0
def train_one_epoch(model, criterion, optimizer, trainloader, device, epoch,
                    log_interval, writer):
    """Train a reconstruction-style model for one epoch.

    Each batch yields a pair ``(q, p)``. Only ``q`` is used: the model's
    output on ``q`` is compared against ``q`` itself by ``criterion``. The
    loss is averaged per logging window, weighted by the batch size, and
    the meters are reset after every log.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss_meter = AverageMeter()
    model.train()
    recoder = Recorder([loss_meter], ['train loss'], writer,
                       batch_time, data_time)

    recoder.tik()
    recoder.data_tik()
    for step, data in enumerate(trainloader):
        recoder.data_tok()

        q, _p = [x.to(device) for x in data]

        optimizer.zero_grad()
        reconstruction = model(q)
        loss = criterion(reconstruction, q)
        loss.backward()
        optimizer.step()

        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        recoder.update([loss.item()], count=q.size(0))

        if step == 0 or step % log_interval == log_interval - 1:
            recoder.log(epoch, step, len(trainloader))
            recoder.reset()
Ejemplo n.º 7
0
def train_seq2seq(model, criterion, optimizer, clip, dataloader, device, epoch,
                  log_interval, writer):
    """Train a seq2seq model for one epoch with gradient-norm clipping.

    Per batch the model runs with its default teacher forcing. The first
    decoding step is dropped from outputs and targets, and three metrics are
    computed:
    - the loss;
    - token accuracy (``accuracy_score``);
    - mean per-sample WER, after removing tokens 0, 1 and 2.

    Gradients are clipped to ``clip`` before the optimizer step. Meters are
    reset after every log.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss_meter, acc_meter, wer_meter = AverageMeter(), AverageMeter(), AverageMeter()
    recoder = Recorder([loss_meter, acc_meter, wer_meter],
                       ['train loss', 'train acc', 'train wer'],
                       writer, batch_time, data_time)
    model.train()

    recoder.tik()
    recoder.data_tik()
    for batch_idx, batch in enumerate(dataloader):
        recoder.data_tok()
        imgs = batch['videos'].cuda()
        # (trg_len, batch) -> (batch, trg_len) as expected by the model
        target = batch['annotations'].permute(1, 0).contiguous().cuda()

        optimizer.zero_grad()
        outputs = model(imgs, target)

        # outputs: (trg_len, batch, output_dim); drop the sos step and flatten
        vocab = outputs.shape[-1]
        flat_outputs = outputs[1:].view(-1, vocab)
        flat_target = target.permute(1, 0)[1:].reshape(-1)

        loss = criterion(flat_outputs, flat_target)

        predicted = torch.max(flat_outputs, 1)[1]
        score = accuracy_score(flat_target.cpu().data.squeeze().numpy(),
                               predicted.cpu().data.squeeze().numpy())

        # Per-sample WER, ignoring padding / sos / eos tokens (0, 1, 2).
        n = imgs.shape[0]
        pred_rows = predicted.view(-1, n).permute(1, 0).tolist()
        gold_rows = flat_target.view(-1, n).permute(1, 0).tolist()
        special_tokens = (0, 1, 2)
        wers = [
            wer([tok for tok in gold if tok not in special_tokens],
                [tok for tok in pred if tok not in special_tokens])
            for pred, gold in zip(pred_rows, gold_rows)
        ]
        batch_wer = sum(wers) / len(wers)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        recoder.update([loss.item(), score, batch_wer], count=imgs.size(0))

        if batch_idx == 0 or (batch_idx + 1) % log_interval == 0:
            recoder.log(epoch, batch_idx, len(dataloader))
            recoder.reset()
Ejemplo n.º 8
0
def train_hcn_lstm(model, criterion, optimizer, trainloader, device, epoch,
                   log_interval, writer, reverse_dict, clip_g):
    """Train the HCN-LSTM model for one epoch, logging loss, WER and BLEU.

    Logging is done by hand with ``AverageMeter`` objects: every
    ``log_interval`` batches (and on the first batch) a line is printed, the
    running averages are written to ``writer``, and the meters are reset.

    Args:
        model: called as ``model(src, src_len_list)``.
        criterion: called as ``criterion(outputs, tgt, src_len_list,
            tgt_len_list)``.
        optimizer: optimizer stepped once per batch.
        trainloader: yields dicts with ``'src'``, ``'tgt'``,
            ``'src_len_list'`` and ``'tgt_len_list'``.
        device: device the batch tensors are moved to.
        epoch: current epoch, used for logging and the summary step.
        log_interval: log every ``log_interval`` batches.
        writer: summary writer receiving the averaged scalars.
        reverse_dict: passed to ``count_bleu``.
        clip_g: accepted for interface compatibility; not used here
            (no gradient clipping is applied).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    avg_wer = AverageMeter()
    avg_bleu = AverageMeter()
    # Set training mode
    model.train()

    end = time.time()
    for i, data in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        # get the inputs and labels; tgt is presumably N x T (see BLEU permute)
        inputs, tgt = data['src'].to(device), data['tgt'].to(device)
        src_len_list = data['src_len_list'].to(device)
        tgt_len_list = data['tgt_len_list'].to(device)

        optimizer.zero_grad()
        # forward
        outputs = model(inputs, src_len_list)

        # compute the loss
        loss = criterion(outputs, tgt, src_len_list, tgt_len_list)

        # backward & optimize
        loss.backward()
        optimizer.step()

        # compute the metrics (local names avoid shadowing any `wer` helper)
        batch_wer = count_wer(outputs, tgt)
        batch_bleu = count_bleu(outputs, tgt.permute(1, 0), reverse_dict)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # update average value
        N = tgt.size(0)
        losses.update(loss.item(), N)
        avg_wer.update(batch_wer, N)
        avg_bleu.update(batch_bleu, N)

        if i == 0 or i % log_interval == log_interval - 1:
            info = ('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t'
                    'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Wer {wer.val:.5f} ({wer.avg:.5f})\t'
                    'Bleu {bleu.val:.5f} ({bleu.avg:.5f})\t'.format(
                        epoch,
                        i,
                        len(trainloader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        wer=avg_wer,
                        bleu=avg_bleu))
            print(info)
            global_step = epoch * len(trainloader) + i
            writer.add_scalar('train loss', losses.avg, global_step)
            writer.add_scalar('train wer', avg_wer.avg, global_step)
            writer.add_scalar('train bleu', avg_bleu.avg, global_step)
            # Reset average meters
            losses.reset()
            avg_wer.reset()
            avg_bleu.reset()
Ejemplo n.º 9
0
def train_maml(model, criterion, optimizer, trainloader, device, epoch,
               log_interval, writer, args):
    """Run one epoch of MAML meta-training.

    Each batch is one task: the first ``args.shot * args.train_way`` samples
    are the support set and the rest the query set. The meta-parameters are
    adapted with ``args.inner_train_steps`` manual gradient steps
    (``args.inner_lr``) on the support set and evaluated on the query set.

    A single meta-update is applied after the whole loader has been
    consumed:

    - ``args.order == 1``: the per-task query-loss gradients w.r.t. the
      adapted weights are averaged and injected into the meta-parameters
      through gradient hooks on a dummy forward/backward pass.
    - ``args.order == 2``: the mean of the per-task query losses is
      back-propagated through the inner loop.

    Only the quantities the chosen order needs are kept per task. An empty
    loader performs no meta-update.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    # Create recorder
    averagers = [avg_loss, avg_acc]
    names = ['train loss', 'train acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set training mode
    model.train()

    recoder.tik()
    recoder.data_tik()
    # Second-order MAML differentiates through the inner-loop updates.
    create_graph = args.order == 2
    task_gradients = []  # order 1: per-task {name: grad} of the query loss
    task_losses = []     # order 2: per-task query losses (graphs retained)
    data_shape = None
    for i, batch in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # get the inputs (task labels are rebuilt below)
        data, _lab = [t.to(device) for t in batch]

        n_support = args.shot * args.train_way
        data_shot = data[:n_support]
        data_query = data[n_support:]
        data_shape = data_shot.size()[-3:]

        # Create fast weights from the current meta-model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Adapt for `inner_train_steps` iterations
        for _ in range(args.inner_train_steps):
            y = create_nshot_task_label(args.train_way, args.shot).to(device)
            logits = model.functional_forward(data_shot, fast_weights)
            loss = criterion(logits, y)
            gradients = torch.autograd.grad(loss,
                                            fast_weights.values(),
                                            create_graph=create_graph)

            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - args.inner_lr * grad)
                for (name, param), grad in zip(fast_weights.items(), gradients))

        # Evaluate the adapted weights on the query set of this task
        y = create_nshot_task_label(args.train_way, args.query).to(device)
        logits = model.functional_forward(data_query, fast_weights)
        loss = criterion(logits, y)

        # Get post-update accuracies
        y_pred = logits.softmax(-1)
        acc = accuracy(y_pred, y)[0]

        # Keep only what the meta-update needs. No per-task backward into
        # the meta-parameters: those grads would be wiped by zero_grad()
        # before the meta-step anyway.
        if args.order == 1:
            gradients = torch.autograd.grad(loss, fast_weights.values())
            task_gradients.append({
                name: g
                for (name, _), g in zip(fast_weights.items(), gradients)
            })
        elif args.order == 2:
            task_losses.append(loss)

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update average value
        vals = [loss.item(), acc]
        recoder.update(vals)

        if i % log_interval == log_interval - 1:
            recoder.log(epoch, i, len(trainloader))
            # Reset average meters
            recoder.reset()

    if args.order == 1 and task_gradients:
        mean_task_gradients = {
            k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
            for k in task_gradients[0].keys()
        }
        hooks = [
            param.register_hook(replace_grad(mean_task_gradients, name))
            for name, param in model.named_parameters()
        ]

        model.train()
        optimizer.zero_grad()
        # Dummy pass to obtain a `loss` whose gradients the hooks replace
        # with the mean task gradients.
        logits = model(
            torch.zeros((args.train_way, ) + data_shape).to(device,
                                                            dtype=torch.float))
        loss = criterion(logits,
                         create_nshot_task_label(args.train_way, 1).to(device))
        loss.backward()
        optimizer.step()

        for h in hooks:
            h.remove()

    elif args.order == 2 and task_losses:
        model.train()
        optimizer.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()
        meta_batch_loss.backward()
        optimizer.step()
Ejemplo n.º 10
0
def train_gcr(model, criterion, optimizer, optimizer_cnn, trainloader, device,
              epoch, log_interval, writer, args):
    """Train the GCR few-shot model for one epoch.

    Each batch is split into a support set (the first
    ``args.shot * args.train_way`` samples) and a query set. The model
    returns two (logits, target) pairs. ``criterion`` combines them into a
    total loss plus two component losses.

    Both optimizers' gradients are cleared every step, but
    ``optimizer_cnn`` is only stepped once ``epoch > 45``. The two component
    losses and two accuracies are logged, and meters are reset after every
    log.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss1_meter, loss2_meter = AverageMeter(), AverageMeter()
    acc1_meter, acc2_meter = AverageMeter(), AverageMeter()
    recoder = Recorder(
        [loss1_meter, loss2_meter, acc1_meter, acc2_meter],
        ['train loss1', 'train loss2', 'train acc1', 'train acc2'],
        writer, batch_time, data_time)
    model.train()

    recoder.tik()
    recoder.data_tik()
    for step, batch in enumerate(trainloader):
        recoder.data_tok()

        data, lab = [t.to(device) for t in batch]

        n_support = args.shot * args.train_way
        support, query = data[:n_support], data[n_support:]

        logits, label, logits2, gt = model(support, query, lab)
        loss, loss1, loss2 = criterion(logits, label, logits2, gt)

        optimizer.zero_grad()
        optimizer_cnn.zero_grad()
        loss.backward()
        if epoch > 45:
            optimizer_cnn.step()
        optimizer.step()

        acc1 = accuracy(logits, label)[0]
        acc2 = accuracy(logits2, gt)[0]

        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        recoder.update([loss1.item(), loss2.item(), acc1, acc2])

        if step % log_interval == log_interval - 1:
            recoder.log(epoch, step, len(trainloader))
            recoder.reset()
Ejemplo n.º 11
0
def train_mn_pn(model, criterion, optimizer, trainloader, device, epoch,
                log_interval, writer, args):
    """Train a matching/prototypical-style few-shot model for one epoch.

    Each batch is split into a support set (the first
    ``args.shot * args.train_way`` samples) and a query set. The model
    returns predictions and the matching episode labels, which feed
    ``criterion`` and ``accuracy``. Meters are reset after every log.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss_meter, acc_meter = AverageMeter(), AverageMeter()
    recoder = Recorder([loss_meter, acc_meter], ['train loss', 'train acc'],
                       writer, batch_time, data_time)
    model.train()

    recoder.tik()
    recoder.data_tik()
    for step, batch in enumerate(trainloader):
        recoder.data_tok()

        data, _lab = [t.to(device) for t in batch]

        n_support = args.shot * args.train_way
        y_pred, label = model(data[:n_support], data[n_support:])
        loss = criterion(y_pred, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = accuracy(y_pred, label)[0]

        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        recoder.update([loss.item(), acc])

        if step % log_interval == log_interval - 1:
            recoder.log(epoch, step, len(trainloader))
            recoder.reset()
Ejemplo n.º 12
0
def train_cnn(model, criterion, optimizer, trainloader, device, epoch,
              log_interval, writer, args):
    """Train the CNN classifier for one epoch and accumulate class prototypes.

    After each optimizer step, ``model.get_feature`` is evaluated on the
    batch (without gradient tracking). The features are summed per ground-
    truth class into a ``(args.num_class, args.feature_dim)`` array.

    Args:
        model: classifier with a ``get_feature(data)`` method.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        trainloader: iterable of ``(data, label)`` batches.
        device: device the batch tensors are moved to.
        epoch: current epoch, only used for logging.
        log_interval: log on the first batch and every ``log_interval``
            batches (0-based indices, like the other training loops).
        writer: summary writer forwarded to ``Recorder``.
        args: namespace with ``num_class``, ``feature_dim``, ``n_base``,
            ``n_reserve`` and ``shot``.

    Returns:
        The accumulated prototypes. Rows ``[:args.n_base]`` are divided by
        ``args.n_reserve`` and the remaining rows by ``args.shot``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    avg_acc = AverageMeter()
    global_proto = numpy.zeros([args.num_class, args.feature_dim])
    # Create recorder
    averagers = [losses, avg_acc]
    names = ['train loss', 'train acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set training mode
    model.train()

    recoder.tik()
    recoder.data_tik()
    # 0-based so the `i == 0` first-batch log below can actually fire.
    for i, batch in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # get the data and labels
        data, lab = [t.to(device) for t in batch]

        optimizer.zero_grad()
        # forward
        outputs = model(data)

        # compute the loss
        loss = criterion(outputs, lab)

        # backward & optimize
        loss.backward()
        optimizer.step()

        # Accumulate per-class feature sums. This is bookkeeping only, so no
        # autograd graph is built; one host transfer per batch, and
        # numpy.add.at handles repeated labels within the batch.
        with torch.no_grad():
            feats = model.get_feature(data).detach().cpu().numpy()
        numpy.add.at(global_proto, lab.cpu().numpy(), feats)

        # compute the metrics
        acc = accuracy(outputs, lab)[0]

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update average value
        vals = [loss.item(), acc]
        recoder.update(vals)

        # logging
        if i == 0 or i % log_interval == log_interval - 1:
            recoder.log(epoch, i, len(trainloader))
            # Reset average meters
            recoder.reset()

    # NOTE(review): assumes each base class contributes n_reserve samples and
    # each novel class `shot` samples per epoch — confirm against the sampler.
    global_proto[:args.n_base] = global_proto[:args.n_base] / args.n_reserve
    global_proto[args.n_base:] = global_proto[args.n_base:] / args.shot
    return global_proto
Ejemplo n.º 13
0
def test_one_epoch(model, testloader, device, epoch, log_interval, writer,
                   h5writer):
    """Evaluate the model on ``testloader`` and write each sample's ``Jr`` output.

    For every batch ``(Qh, Ph, glove_angles, group_names)`` the model is run on
    ``(Qh, Ph)``. Its six losses from ``model.calculate_losses`` are averaged,
    weighted by batch size. For each sample in the batch,
    ``h5writer.write(group_name, Jr_i, glove_angle_i)`` is called.

    Returns:
        The running average of the total loss (``Ltotal``).
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    (meter_total, meter_rec, meter_ltc,
     meter_ee, meter_adv_g, meter_adv_d) = (AverageMeter() for _ in range(6))
    model.eval()
    recoder = Recorder(
        [meter_total, meter_rec, meter_ltc, meter_ee, meter_adv_g, meter_adv_d],
        ['test Ltotal', 'test Lrec', 'test Lltc', 'test Lee', 'test Ladv G',
         'test Ladv D'],
        writer, batch_time, data_time)

    recoder.tik()
    recoder.data_tik()
    with torch.no_grad():
        for step, sample in enumerate(testloader):
            recoder.data_tok()

            Qh, Ph, glove_angles, group_names = sample
            Qh, Ph, glove_angles = (t.to(device) for t in (Qh, Ph, glove_angles))

            outputs = model(Qh, Ph)
            loss_terms = model.calculate_losses(outputs)
            # Only Qh (rebound from the model output) and Jr are used below.
            (Qh, Ph, _Qh_hat, _Qh_ew, _Ph_ew, Jr, _Qr_ew, _Pr_ew,
             _Phi_h, _Phi_r, _real_pred, _fake_pred) = outputs
            Ltotal, Lrec, Lltc, Lee, Ladv_G, Ladv_D = loss_terms

            for name, joint_angle, glove_angle in zip(group_names, Jr,
                                                      glove_angles):
                h5writer.write(name, joint_angle, glove_angle)

            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            loss_values = [
                term.item()
                for term in (Ltotal, Lrec, Lltc, Lee, Ladv_G, Ladv_D)
            ]
            recoder.update(loss_values, count=Qh.size(0))

            if step == 0 or step % log_interval == log_interval - 1:
                recoder.log(epoch, step, len(testloader), mode='Test')

    return meter_total.avg
Ejemplo n.º 14
0
def eval_gcr(model, criterion,
             valloader, device, epoch,
             log_interval, writer, args):
    """Evaluate a GCR few-shot model on validation episodes.

    Each batch is split into ``args.shot * args.test_way`` support samples,
    followed by the query samples. ``model(..., mode='eval')`` returns two
    ``(logits, target)`` pairs. ``criterion`` returns ``(loss, loss1, loss2)``;
    the two partial losses are recorded together with each head's accuracy.

    Args:
        model: called as ``model(shot, query, lab, mode='eval')``.
        criterion: callable returning ``(loss, loss1, loss2)``.
        valloader: iterable of ``(data, lab)`` episodes.
        device: device each batch is moved to.
        epoch: current epoch (logging only).
        log_interval: log every ``log_interval`` batches and on the last one.
        writer: summary writer handed to ``Recorder``.
        args: namespace providing ``shot`` and ``test_way``.

    Returns:
        ``(avg_acc1, statistic)``: the recorder average for ``'val acc1'`` and a
        numpy array of the per-batch ``acc1`` values.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss1 = AverageMeter()
    avg_loss2 = AverageMeter()
    avg_acc1 = AverageMeter()
    avg_acc2 = AverageMeter()
    statistic = []
    # Create recorder
    averagers = [avg_loss1, avg_loss2, avg_acc1, avg_acc2]
    names = ['val loss1', 'val loss2', 'val acc1', 'val acc2']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set evaluation mode
    model.eval()
    # Number of support samples at the front of every episode.
    p = args.shot * args.test_way

    recoder.tik()
    recoder.data_tik()
    with torch.no_grad():
        for i, batch in enumerate(valloader):
            # measure data loading time
            recoder.data_tok()

            # get the inputs and labels
            data, lab = [_.to(device) for _ in batch]
            data_shot = data[:p]
            data_query = data[p:]

            logits, label, logits2, gt = \
                model(data_shot, data_query, lab, mode='eval')
            # Only the two partial losses are recorded, not the combined one.
            _, loss1, loss2 = criterion(logits, label, logits2, gt)

            # compute the metrics
            acc1 = accuracy(logits, label)[0]
            acc2 = accuracy(logits2, gt)[0]

            # measure elapsed time
            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            # update average value & account statistic
            recoder.update([loss1.item(), loss2.item(), acc1, acc2])
            statistic.append(acc1.data.cpu().numpy())

            # Also log the final batch, so loaders shorter than (or not a
            # multiple of) log_interval still report their averages.
            if (i % log_interval == log_interval - 1
                    or i == len(valloader) - 1):
                recoder.log(epoch, i, len(valloader), mode='Eval')

    return recoder.get_avg('val acc1'), numpy.array(statistic)
Ejemplo n.º 15
0
def test_text2sign(model, criterion, testloader, device, epoch, log_interval,
                   writer):
    """Evaluate the text-to-sign model on ``testloader`` without gradients.

    Each batch is a mapping with ``'input'`` (commented upstream as N x T) and
    ``'tgt'`` (N x T2 x J x D). The model is called as ``model(src, tgt, 0)``.
    The meaning of the trailing ``0`` is not visible here; presumably it
    disables teacher forcing (confirm). The loss compares the outputs with
    ``tgt`` minus its first time step.

    Returns:
        The batch-size-weighted average test loss.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    model.eval()
    recoder = Recorder([loss_meter], ['test loss'], writer, batch_time,
                       data_time)

    recoder.tik()
    recoder.data_tik()
    with torch.no_grad():
        for step, sample in enumerate(testloader):
            recoder.data_tok()

            src = sample['input'].to(device)
            tgt = sample['tgt'].to(device)

            outputs = model(src, tgt, 0)
            loss = criterion(outputs, tgt[:, 1:, :, :])

            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            recoder.update([loss.item()], count=src.size(0))

            if step == 0 or step % log_interval == log_interval - 1:
                recoder.log(epoch, step, len(testloader), mode='Test')

    return loss_meter.avg
Ejemplo n.º 16
0
def eval_c3d(model, criterion, valloader,
        device, epoch, log_interval, writer, eval_samples):
    """Evaluate a C3D classifier on at most ``eval_samples`` validation batches.

    Args:
        model: classifier mapping ``data`` to logits.
        criterion: loss applied as ``criterion(outputs, lab)``.
        valloader: iterable of ``(data, lab)`` batches.
        device: device each batch is moved to.
        epoch: current epoch (logging only).
        log_interval: log every ``log_interval`` batches and on the last
            evaluated batch.
        writer: summary writer handed to ``Recorder``.
        eval_samples: maximum number of batches to evaluate (shortens
            validation).

    Returns:
        The recorder average for ``'val top1'``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    avg_top1 = AverageMeter()
    avg_top5 = AverageMeter()
    averagers = [losses, avg_top1, avg_top5]
    names = ['val loss', 'val top1', 'val top5']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set evaluation mode
    model.eval()
    # Index of the last batch that will actually be evaluated; used so the
    # final log line is emitted even when the cap stops the loop early.
    last_index = min(len(valloader), eval_samples) - 1

    recoder.tik()
    recoder.data_tik()
    with torch.no_grad():
        for i, batch in enumerate(valloader):
            # Cap the evaluation at exactly ``eval_samples`` batches (the old
            # ``i > eval_samples`` check evaluated one extra batch).
            if i >= eval_samples:
                break
            # measure data loading time
            recoder.data_tok()

            # get the data and labels
            data, lab = [_.to(device) for _ in batch]

            # forward
            outputs = model(data)

            # compute the loss
            loss = criterion(outputs, lab)

            # compute the metrics
            top1, top5 = accuracy(outputs, lab, topk=(1, 5))

            # measure elapsed time
            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            # update average value
            recoder.update([loss.item(), top1, top5])

            # logging
            if i == 0 or i % log_interval == log_interval - 1 or i == last_index:
                recoder.log(epoch, i, len(valloader), mode='Eval')

    return recoder.get_avg('val top1')
Ejemplo n.º 17
0
def eval_cnn(model, criterion, valloader,
        device, epoch, log_interval, writer, args):
    """Evaluate a plain CNN classifier on few-shot validation batches.

    The whole batch (support and query samples alike) is classified in one
    forward pass.

    Args:
        model: classifier mapping ``data`` to logits.
        criterion: loss applied as ``criterion(outputs, lab)``.
        valloader: iterable of ``(data, lab)`` batches.
        device: device each batch is moved to.
        epoch: current epoch (logging only).
        log_interval: log every ``log_interval`` batches and on the last one.
        writer: summary writer handed to ``Recorder``.
        args: unused; kept for interface compatibility with the other
            ``eval_*`` functions.

    Returns:
        The recorder average for ``'val acc'``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    avg_acc = AverageMeter()
    averagers = [losses, avg_acc]
    names = ['val loss', 'val acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set evaluation mode
    model.eval()

    recoder.tik()
    recoder.data_tik()
    with torch.no_grad():
        for i, batch in enumerate(valloader):
            # measure data loading time
            recoder.data_tok()

            # get the data and labels
            data, lab = [_.to(device) for _ in batch]

            # forward. Concatenating data[:p] and data[p:] just rebuilt
            # ``data``, so the batch is fed directly without that copy.
            outputs = model(data)

            # compute the loss
            loss = criterion(outputs, lab)

            # compute the metrics
            acc = accuracy(outputs, lab)[0]

            # measure elapsed time
            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            # update average value
            recoder.update([loss.item(), acc])

            # logging
            if i == 0 or i % log_interval == log_interval - 1 or i == len(valloader) - 1:
                recoder.log(epoch, i, len(valloader), mode='Eval')

    return recoder.get_avg('val acc')
Ejemplo n.º 18
0
def train_text2sign(model, criterion, optimizer, trainloader, device, epoch,
                    log_interval, writer):
    """Run one training epoch of the text-to-sign model.

    Each batch is a mapping with ``'input'`` (commented upstream as N x T) and
    ``'tgt'`` (N x T2 x J x D). The loss compares ``model(src, tgt)`` with
    ``tgt`` minus its first time step. The meters are logged and reset on the
    first batch and then every ``log_interval`` batches.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    model.train()
    recoder = Recorder([loss_meter], ['train loss'], writer, batch_time,
                       data_time)

    recoder.tik()
    recoder.data_tik()
    for step, sample in enumerate(trainloader):
        recoder.data_tok()

        src = sample['input'].to(device)
        tgt = sample['tgt'].to(device)

        optimizer.zero_grad()
        loss = criterion(model(src, tgt), tgt[:, 1:, :, :])
        loss.backward()
        optimizer.step()

        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        recoder.update([loss.item()], count=src.size(0))

        should_log = step == 0 or step % log_interval == log_interval - 1
        if should_log:
            recoder.log(epoch, step, len(trainloader))
            recoder.reset()
Ejemplo n.º 19
0
def test_gcr(model, criterion,
             valloader, device, epoch,
             log_interval, writer, args, relation):
    """Evaluate GCR over the joint base + novel label space.

    Embeddings from ``model.baseModel`` are scored by ``relation`` against the
    concatenation of ``model.global_base`` and ``model.global_novel``.

    Args:
        model: object exposing ``baseModel``, ``global_base`` and
            ``global_novel``.
        criterion: loss applied as ``criterion(logits, lab)``.
        valloader: iterable of ``(data, lab)`` batches.
        device: device each batch is moved to.
        epoch: current epoch (logging only).
        log_interval: log every ``log_interval`` batches and on the last one.
        writer: summary writer handed to ``Recorder``.
        args: unused; kept for interface compatibility.
        relation: callable ``relation(proto, global_set)`` returning logits.

    Returns:
        The recorder average for ``'val acc'``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    # Create recorder
    averagers = [avg_loss, avg_acc]
    names = ['val loss', 'val acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set evaluation mode
    model.eval()

    with torch.no_grad():
        # The global prototype set does not depend on the batch, so build it
        # once instead of re-concatenating on every iteration.
        global_set = torch.cat([model.global_base, model.global_novel])

        recoder.tik()
        recoder.data_tik()
        for i, batch in enumerate(valloader):
            # measure data loading time
            recoder.data_tok()

            # get the inputs and labels
            data, lab = [_.to(device) for _ in batch]

            # forward
            proto = model.baseModel(data)
            logits = relation(proto, global_set)

            # compute the loss
            loss = criterion(logits, lab)

            # compute the metrics
            acc = accuracy(logits, lab)[0]

            # measure elapsed time
            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            # update average value
            recoder.update([loss.item(), acc])

            # Also log the final batch so short loaders still report.
            if (i % log_interval == log_interval - 1
                    or i == len(valloader) - 1):
                recoder.log(epoch, i, len(valloader), mode='Test')

    return recoder.get_avg('val acc')
Ejemplo n.º 20
0
def test_one_epoch(model, criterion, testloader, device, epoch, log_interval,
                   writer):
    """Evaluate the model on ``testloader`` with a reconstruction-style loss.

    Each batch yields two tensors ``(q, p)``. Both are moved to ``device``,
    but only ``q`` is used: the loss is ``criterion(model(q), q)``.

    Returns:
        The batch-size-weighted average test loss.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    model.eval()
    recoder = Recorder([loss_meter], ['test loss'], writer, batch_time,
                       data_time)

    recoder.tik()
    recoder.data_tik()
    with torch.no_grad():
        for step, sample in enumerate(testloader):
            recoder.data_tok()

            q, _p = [t.to(device) for t in sample]

            loss = criterion(model(q), q)

            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            recoder.update([loss.item()], count=q.size(0))

            if step == 0 or step % log_interval == log_interval - 1:
                recoder.log(epoch, step, len(testloader), mode='Test')

    return loss_meter.avg
Ejemplo n.º 21
0
def eval_mn_pn(model, criterion,
               valloader, device, epoch,
               log_interval, writer, args):
    """Evaluate a few-shot model on validation episodes.

    Each batch is split into ``args.shot * args.test_way`` support samples,
    followed by the query samples. ``model(shot, query, mode='eval')`` returns
    ``(y_pred, label)``, which are used for both loss and accuracy.

    Args:
        model: few-shot model called with ``mode='eval'``.
        criterion: loss applied as ``criterion(y_pred, label)``.
        valloader: iterable of ``(data, lab)`` episodes.
        device: device each batch is moved to.
        epoch: current epoch (logging only).
        log_interval: log every ``log_interval`` batches and on the last one.
        writer: summary writer handed to ``Recorder``.
        args: namespace providing ``shot`` and ``test_way``.

    Returns:
        ``(avg_acc, statistic)``: the recorder average for ``'val acc'`` and a
        numpy array of the per-batch accuracies.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    statistic = []
    # Create recorder
    averagers = [avg_loss, avg_acc]
    names = ['val loss', 'val acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set evaluation mode
    model.eval()
    # Number of support samples at the front of every episode.
    p = args.shot * args.test_way

    recoder.tik()
    recoder.data_tik()
    with torch.no_grad():
        for i, batch in enumerate(valloader):
            # measure data loading time
            recoder.data_tok()

            # get the inputs (labels are regenerated by the model)
            data, lab = [_.to(device) for _ in batch]
            data_shot = data[:p]
            data_query = data[p:]

            y_pred, label = model(data_shot, data_query, mode='eval')
            # compute the loss
            loss = criterion(y_pred, label)

            # compute the metrics
            acc = accuracy(y_pred, label)[0]

            # measure elapsed time
            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            # update average value & account statistic
            recoder.update([loss.item(), acc])
            statistic.append(acc.data.cpu().numpy())

            # Also log the final batch so short loaders still report.
            if (i % log_interval == log_interval - 1
                    or i == len(valloader) - 1):
                recoder.log(epoch, i, len(valloader), mode='Eval')

    return recoder.get_avg('val acc'), numpy.array(statistic)
Ejemplo n.º 22
0
        t.manual_seed(args.seed)
        cudnn.deterministic = True

    vocab_path = os.path.join(data_dir, "vocab")
    # 根据常见的词典,得到映射表
    word2id, id2word, vocab = initialize_by_vocabulary(vocab_path)

    # 如果不存在目标路径,则创建
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    # 创建recorder,记录训练信息
    if (args.test):
        writer = SummaryWriter(
            f"{args.result_dir}/test_{args.project}__{args.timestamp}")
    else:
        writer = SummaryWriter(
            f"{args.result_dir}/train_{args.project}__{args.timestamp}")
    recorder = Recorder(args, writer, id2word)

    # 打印参数
    # for arg in vars(args):
    #  print (arg, getattr(args, arg))
    print(args)
    if not args.test:
        train(args, len(vocab))
    else:
        test(args, len(vocab))
Ejemplo n.º 23
0
def test_mn(model, global_proto, criterion,
            valloader, device, epoch,
            log_interval, writer, args):
    """Evaluate generalized few-shot classification against ``global_proto``.

    Predictions come from ``model.gfsl_test(global_proto, data)`` and are
    compared with the loader labels.

    Args:
        model: object exposing ``gfsl_test(global_proto, data)``.
        global_proto: prototypes forwarded unchanged to ``gfsl_test``.
        criterion: loss applied as ``criterion(y_pred, lab)``.
        valloader: iterable of ``(data, lab)`` batches.
        device: device each batch is moved to.
        epoch: current epoch (logging only).
        log_interval: log every ``log_interval`` batches and on the last one.
        writer: summary writer handed to ``Recorder``.
        args: unused; kept for interface compatibility.

    Returns:
        The recorder average for ``'val acc'``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    # Create recorder
    averagers = [avg_loss, avg_acc]
    names = ['val loss', 'val acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set evaluation mode
    model.eval()

    recoder.tik()
    recoder.data_tik()
    with torch.no_grad():
        for i, batch in enumerate(valloader):
            # measure data loading time
            recoder.data_tok()

            # get the inputs and labels
            data, lab = [_.to(device) for _ in batch]

            # forward
            y_pred = model.gfsl_test(global_proto, data)
            # compute the loss
            loss = criterion(y_pred, lab)

            # compute the metrics
            acc = accuracy(y_pred, lab)[0]

            # measure elapsed time
            recoder.tok()
            recoder.tik()
            recoder.data_tik()

            # update average value
            recoder.update([loss.item(), acc])

            # Also log the final batch so short loaders still report.
            if (i % log_interval == log_interval - 1
                    or i == len(valloader) - 1):
                recoder.log(epoch, i, len(valloader), mode='Test')

    return recoder.get_avg('val acc')
Ejemplo n.º 24
0
def train_c3d(model, criterion, optimizer, trainloader, device, epoch,
              log_interval, writer):
    """Train a C3D classifier for one epoch, tracking loss, top-1 and top-5.

    The meters are logged and reset on the first batch and then every
    ``log_interval`` batches.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss_meter, top1_meter, top5_meter = (AverageMeter(), AverageMeter(),
                                          AverageMeter())
    recoder = Recorder([loss_meter, top1_meter, top5_meter],
                       ['train loss', 'train top1', 'train top5'],
                       writer, batch_time, data_time)
    model.train()

    recoder.tik()
    recoder.data_tik()
    for step, batch in enumerate(trainloader):
        recoder.data_tok()

        inputs, targets = [t.to(device) for t in batch]

        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        top1, top5 = accuracy(logits, targets, topk=(1, 5))

        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        recoder.update([loss.item(), top1, top5])

        if step == 0 or step % log_interval == log_interval - 1:
            recoder.log(epoch, step, len(trainloader))
            recoder.reset()
Ejemplo n.º 25
0
def main():
    """Entry point: parse config, build data/model, then evaluate or train.

    Sets the module globals ``args``, ``logger``, ``writer``,
    ``dataset_configs``, ``best_accuracy`` and ``best_accuracy_epoch``.
    With ``--evaluate``, it runs a single validation pass and returns.
    Otherwise it trains for ``args.n_epoch`` epochs, validating every
    ``args.eval_freq`` epochs and on the last one, and saves a checkpoint
    whenever validation accuracy improves.

    Raises:
        FileNotFoundError: if ``args.resume`` is non-empty but is not a file.
    """
    # utils variable
    global args, logger, writer, dataset_configs
    # statistics variable
    global best_accuracy, best_accuracy_epoch
    best_accuracy, best_accuracy_epoch = 0, 0
    # configs
    dataset_configs = get_and_save_args(parser)
    parser.set_defaults(**dataset_configs)
    args = parser.parse_args()
    # select GPUs
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # Create snapshot_pref dir for copying code and saving models.
    os.makedirs(args.snapshot_pref, exist_ok=True)

    resume_exists = os.path.isfile(args.resume)
    if resume_exists:
        # When resuming, keep logs/checkpoints next to the resumed checkpoint.
        args.snapshot_pref = os.path.dirname(args.resume)

    logger = Prepare_logger(args, eval=args.evaluate)

    if not args.evaluate:
        logger.info(f'\nCreating folder: {args.snapshot_pref}')
        logger.info('\nRuntime args\n\n{}\n'.format(
            json.dumps(vars(args), indent=4)))
    else:
        logger.info(
            f'\nLog file will be save in a {args.snapshot_pref}/Eval.log.')
    # Dataset
    train_dataloader = DataLoader(AVEDataset('./data/', split='train'),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)

    test_dataloader = DataLoader(AVEDataset('./data/', split='test'),
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 num_workers=8,
                                 pin_memory=True)
    # Model setting
    mainModel = main_model()
    mainModel = nn.DataParallel(mainModel).cuda()
    learned_parameters = mainModel.parameters()
    optimizer = torch.optim.Adam(learned_parameters, lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.5)
    criterion = nn.BCEWithLogitsLoss().cuda()
    criterion_event = nn.CrossEntropyLoss().cuda()
    # Resume from a checkpoint
    if resume_exists:
        logger.info(f"\nLoading Checkpoint: {args.resume}\n")
        mainModel.load_state_dict(torch.load(args.resume))
    elif args.resume != "":
        raise FileNotFoundError(
            f"Checkpoint to resume from not found: {args.resume}")
    # Only evaluate
    if args.evaluate:
        logger.info(f"\nStart Evaluation..")
        validate_epoch(mainModel,
                       test_dataloader,
                       criterion,
                       criterion_event,
                       epoch=0,
                       eval_only=True)
        return
    # Tensorboard and code backup
    writer = SummaryWriter(args.snapshot_pref)
    recorder = Recorder(args.snapshot_pref, ignore_folder="Exps/")
    recorder.writeopt(args)
    # Training and testing
    for epoch in range(args.n_epoch):
        train_epoch(mainModel, train_dataloader, criterion,
                    criterion_event, optimizer, epoch)

        if ((epoch + 1) % args.eval_freq == 0) or (epoch == args.n_epoch - 1):
            acc = validate_epoch(mainModel, test_dataloader, criterion,
                                 criterion_event, epoch)
            if acc > best_accuracy:
                best_accuracy = acc
                best_accuracy_epoch = epoch
                save_checkpoint(
                    mainModel.state_dict(),
                    top1=best_accuracy,
                    task='Supervised',
                    epoch=epoch + 1,
                )
        scheduler.step()