def predict(args, model, data_loader, distance, hall):

    with torch.no_grad():
        model.eval()
        hall.eval()
        # each batch represents one episode (support data + query data)
        for i, (data, _) in enumerate(data_loader):
            # split the episode into support and query data
            support_input = data[:args.n_way * args.n_shot].to(args.device)
            query_input = data[args.n_way * args.n_shot:].to(args.device)

            # class prototypes: average the real support features together
            # with features hallucinated from the first shot of each class
            proto = model(support_input).view(args.n_way, args.n_shot, -1)
            new_proto = torch.empty(
                [args.n_way, args.n_shot + args.n_aug, args.dim]).to(args.device)
            for c in range(args.n_way):
                fake = hall(proto[c][0])
                new_proto[c] = torch.cat([proto[c], fake], dim=0)
            new_proto = new_proto.mean(1)

            # classify queries with the distance module passed in; do not
            # re-instantiate it here, or any loaded parameters are discarded
            feature = model(query_input)
            logits = distance(new_proto, feature)
            preds = torch.argmax(logits, dim=1)
            with open(args.output_csv, 'a') as f:
                f.write("{}".format(i))
                for pred in preds:
                    f.write(",{}".format(pred.item()))
                f.write("\n")
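
# The Distance module referenced above is not shown in this listing. Below is
# a minimal sketch, assuming the usual prototypical-network choices: negative
# squared Euclidean distance or cosine similarity, plus a learned pairwise
# scorer when args.distance_type == 'param' (its exact architecture here is an
# assumption, not the original implementation).
import torch
import torch.nn as nn
import torch.nn.functional as F

class Distance(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.distance_type = args.distance_type
        if self.distance_type == 'param':
            # assumed architecture: MLP scoring each (query, prototype) pair
            self.param = nn.Sequential(
                nn.Linear(args.dim * 2, args.dim),
                nn.ReLU(),
                nn.Linear(args.dim, 1)).to(args.device)

    def forward(self, proto, feature):
        # proto: [n_way, dim], feature: [n_query_total, dim]
        # returns logits of shape [n_query_total, n_way]
        if self.distance_type == 'cosine':
            return F.normalize(feature, dim=1) @ F.normalize(proto, dim=1).t()
        if self.distance_type == 'param':
            n_q, n_w = feature.size(0), proto.size(0)
            pairs = torch.cat([feature.unsqueeze(1).expand(n_q, n_w, -1),
                               proto.unsqueeze(0).expand(n_q, n_w, -1)], dim=2)
            return self.param(pairs).squeeze(-1)
        # default: negative squared Euclidean distance
        return -torch.cdist(feature, proto).pow(2)
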
def val(data_loader, model, hall, criterion, args):
    with torch.no_grad():
        model.eval()
        hall.eval()
        total_loss = []
        total_acc = []
        for _, data in enumerate(data_loader):
            image, _ = data
            support = {'image': image[:args.n_way*args.n_shot].to(args.device),
                       'label': torch.LongTensor([i//args.n_shot for i in range(args.n_way*args.n_shot)])}
            query = {'image': image[args.n_way*args.n_shot:].to(args.device),
                     'label': torch.LongTensor([i//args.n_query for i in range(args.n_way*args.n_query)]).to(args.device)}
            proto = model(support['image']).view(args.n_way, args.n_shot, -1)
            new_proto = torch.empty(
                [args.n_way, args.n_shot+args.n_aug, args.dim]).to(args.device)
            for c in range(args.n_way):
                fake = hall(proto[c][0])
                new_proto[c] = torch.cat([proto[c], fake], dim=0)
            new_proto = new_proto.mean(1)

            feature = model(query['image'])
            distance = Distance(args)
            logits = distance(new_proto, feature)
            loss = criterion(logits, query['label'])
            total_loss.append(loss.item())
            accuracy = calculate_acc(logits, query['label'])
            total_acc.append(accuracy)

    print('Validation: loss {:.3f}, acc {:.3f}\n'.format(
        np.mean(total_loss), np.mean(total_acc)))
    return np.mean(total_loss), np.mean(total_acc)
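
# calculate_acc is also not defined in this listing. A minimal sketch, assuming
# it simply compares the argmax of the logits against the query labels:
def calculate_acc(logits, label):
    # logits: [n_query_total, n_way], label: [n_query_total]
    preds = torch.argmax(logits, dim=1)
    return (preds == label).float().mean().item()
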
def train(data_loader, model, hall, criterion, optimizer, args):
    model.train()
    hall.train()
    total_loss = []
    total_acc = []
    for step, data in enumerate(data_loader):
        image, _ = data
        support = {'image': image[:args.n_way*args.n_shot].to(args.device),
                   'label': torch.LongTensor([i//args.n_shot for i in range(args.n_way*args.n_shot)])}
        query = {'image': image[args.n_way*args.n_shot:].to(args.device),
                 'label': torch.LongTensor([i//args.n_query for i in range(args.n_way*args.n_query)]).to(args.device)}
        proto = model(support['image']).view(args.n_way, args.n_shot, -1)
        new_proto = torch.empty(
            [args.n_way, args.n_shot+args.n_aug, args.dim]).to(args.device)
        for c in range(args.n_way):
            fake = hall(proto[c][0])
            new_proto[c] = torch.cat([proto[c], fake], dim=0)
        new_proto = new_proto.mean(1)

        feature = model(query['image'])
        distance = Distance(args)
        logits = distance(new_proto, feature)
        loss = criterion(logits, query['label'])
        total_loss.append(loss.item())
        accuracy = calculate_acc(logits, query['label'])
        total_acc.append(accuracy)

        optimizer['m'].zero_grad()
        optimizer['h'].zero_grad()
        loss.backward()
        optimizer['m'].step()
        optimizer['h'].step()
        if step % 50 == 0:
            print('step {}: loss {:.3f}, acc {:.3f}'.format(
                step, np.mean(total_loss), np.mean(total_acc)), end='\r')
    print('Training: loss {:.3f}, acc {:.3f}'.format(
        np.mean(total_loss), np.mean(total_acc)), end='\n')
    return np.mean(total_loss), np.mean(total_acc)
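
# The hallucinator (hall) is not defined in this listing either. From its use
# above it must map one support feature of size args.dim to args.n_aug fake
# features of the same size, since its output is concatenated with proto[c]
# along dim 0. A minimal sketch, assuming a noise-conditioned MLP (the real
# architecture may differ):
class Hallucinator(nn.Module):
    def __init__(self, dim, n_aug):
        super().__init__()
        self.dim = dim
        self.n_aug = n_aug
        self.net = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, dim))

    def forward(self, seed):
        # seed: [dim] -> hallucinated features: [n_aug, dim]
        seed = seed.unsqueeze(0).expand(self.n_aug, -1)
        noise = torch.randn(self.n_aug, self.dim, device=seed.device)
        return self.net(torch.cat([seed, noise], dim=1))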


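# set_seed and worker_init_fn are used below but not shown. A minimal sketch,
# assuming the common PyTorch reproducibility pattern:
import random

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def worker_init_fn(worker_id):
    # give each DataLoader worker a distinct, deterministic numpy seed
    np.random.seed(np.random.get_state()[1][0] + worker_id)
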
if __name__ == '__main__':
    args = parse_args()
    set_seed(123)
    test_dataset = MiniDataset(args.test_csv, args.test_data_dir)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.n_way *
                             (args.n_query + args.n_shot),
                             num_workers=3,
                             pin_memory=False,
                             worker_init_fn=worker_init_fn,
                             sampler=TestSampler(args.testcase_csv))

    # TODO: load your trained model and hallucinator
    model = Convnet4(out_channels=args.dim).to(args.device)
    model.load_state_dict(torch.load(args.model))
    # NOTE: the hallucinator's constructor and checkpoint flag below are
    # assumptions (see the Hallucinator sketch above); adjust to your project
    hall = Hallucinator(args.dim, args.n_aug).to(args.device)
    hall.load_state_dict(torch.load(args.hall))
    distance = Distance(args)
    if args.distance_type == 'param':
        assert args.param
        distance.param.load_state_dict(torch.load(args.param))
    with open(args.output_csv, 'w') as f:
        f.write('episode_id')
        for i in range(75):
            f.write(',query{}'.format(i))
        f.write('\n')
    predict(args, model, test_loader, distance, hall)
def train(data_loader, model, hall, disc, criterion, optimizer, args):
    model.train()
    hall.train()
    disc.train()
    total_loss = []
    total_loss_d = []
    total_loss_g = []
    total_acc = []
    for step, data in enumerate(data_loader):
        image, _ = data
        support = {
            'image':
            image[:args.n_way * args.n_shot].to(args.device),
            'label':
            torch.LongTensor(
                [i // args.n_shot for i in range(args.n_way * args.n_shot)])
        }
        query = {
            'image':
            image[args.n_way * args.n_shot:].to(args.device),
            'label':
            torch.LongTensor([
                i // args.n_query for i in range(args.n_way * args.n_query)
            ]).to(args.device)
        }
        proto = model(support['image']).view(args.n_way, args.n_shot, -1)

        # Train Model
        new_proto = torch.empty(
            [args.n_way, args.n_shot + args.n_aug, args.dim]).to(args.device)
        for c in range(args.n_way):
            fake = hall(proto[c][0])
            new_proto[c] = torch.cat([proto[c], fake], dim=0)
        new_proto = new_proto.mean(1)

        feature = model(query['image'])
        distance = Distance(args)
        logits = distance(new_proto, feature)
        loss = criterion(logits, query['label'])
        total_loss.append(loss.item())
        accuracy = calculate_acc(logits, query['label'])
        total_acc.append(accuracy)

        optimizer['m'].zero_grad()
        optimizer['h'].zero_grad()
        loss.backward(retain_graph=True)
        optimizer['m'].step()
        optimizer['h'].step()

        # Train Discriminator
        loss_real = 0
        loss_fake = 0
        for c in range(args.n_way):
            fake = hall(proto[c][0])
            loss_real += torch.mean(disc(proto[c]))
            loss_fake += torch.mean(disc(fake))
        loss_d = -loss_real + loss_fake
        total_loss_d.append(loss_d.item())

        optimizer['d'].zero_grad()
        loss_d.backward(retain_graph=True)
        optimizer['d'].step()
        # WGAN-style hard weight clipping on the critic
        disc.weight_cliping()

        # Train Generator (hallucinator), once every 5 critic steps
        if step % 5 == 0:
            loss_g = 0
            for c in range(args.n_way):
                fake = hall(proto[c][0])
                loss_g += -torch.mean(disc(fake))
            total_loss_g.append(loss_g.item())

            optimizer['h'].zero_grad()
            loss_g.backward()
            optimizer['h'].step()

        if step % 50 == 0:
            print(
                'step {}: loss_D {:.3f}, loss_G {:.3f}, loss {:.3f}, acc {:.3f}'
                .format(step, np.mean(total_loss_d), np.mean(total_loss_g),
                        np.mean(total_loss), np.mean(total_acc)),
                end='\r')
    print('Training: loss_D {:.3f}, loss_G {:.3f}, loss {:.3f}, acc {:.3f}'.
          format(np.mean(total_loss_d), np.mean(total_loss_g),
                 np.mean(total_loss), np.mean(total_acc)),
          end='\n')
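
# The discriminator (disc) used above is not shown. The -real + fake critic
# loss, the weight_cliping() call, and the generator update every 5 steps all
# follow the WGAN recipe, so here is a minimal sketch, assuming a small MLP
# critic over feature vectors with hard weight clipping (the method name is
# kept as weight_cliping to match the call site):
class Discriminator(nn.Module):
    def __init__(self, dim, clip_value=0.01):
        super().__init__()
        self.clip_value = clip_value
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LeakyReLU(0.2),
            nn.Linear(dim, 1))

    def forward(self, x):
        # x: [batch, dim] -> critic scores: [batch, 1]
        return self.net(x)

    def weight_cliping(self):
        # clamp every parameter into [-c, c], as in the original WGAN
        for p in self.parameters():
            p.data.clamp_(-self.clip_value, self.clip_value)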